package edu.berkeley.nlp.lm.collections;

import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;

/**
 * A map from objects to doubles. Includes convenience methods for getting,
 * setting, and incrementing element counts. Objects not in the counter
 * return the default count (zero unless changed with
 * {@link #setDefaultCount(double)}). The counter is backed by a
 * {@link HashMap}.
 * <p>
 * The sum of all counts is cached and recomputed lazily: every mutating method
 * of this class marks the cache as stale. The views returned by
 * {@link #keySet()}, {@link #entrySet()} and {@link #getEntrySet()} are the
 * live views of the backing map; if counts are changed through them, call
 * {@link #setDirty(boolean) setDirty(true)} before relying on
 * {@link #totalCount()}.
 *
 * @author lots of people
 */
public class Counter<E> implements Serializable
{

	private static final long serialVersionUID = 1L;

	/** Backing store: key -> count. */
	private final Map<E, Double> entries;

	/** True when {@link #cacheTotal} may no longer match the sum of the entries. */
	private boolean dirty = true;

	/** Last total computed by {@link #totalCount()}; only valid while {@link #dirty} is false. */
	private double cacheTotal = 0.0;

	/** Count reported by {@link #getCount(Object)} for keys that are not in the counter. */
	private double defaultCount = 0.0;

	/**
	 * @return the count reported for keys that are not in the counter
	 */
	public double getDefaultCount() {
		return defaultCount;
	}

	/**
	 * Sets the count reported for keys that are not in the counter.
	 *
	 * @param deflt
	 *            the new default count
	 */
	public void setDefaultCount(final double deflt) {
		this.defaultCount = deflt;
	}

	/**
	 * The elements in the counter. This is the live key set of the backing
	 * map.
	 *
	 * @return set of keys
	 */
	public Set<E> keySet() {
		return entries.keySet();
	}

	/**
	 * The (key, count) pairs in the counter. This is the live entry set of the
	 * backing map.
	 *
	 * @return set of entries
	 */
	public Set<Entry<E, Double>> entrySet() {
		return entries.entrySet();
	}

	/**
	 * The number of entries in the counter (not the total count -- use
	 * totalCount() instead).
	 */
	public int size() {
		return entries.size();
	}

	/**
	 * True if there are no entries in the counter (false does not mean
	 * totalCount > 0)
	 */
	public boolean isEmpty() {
		return size() == 0;
	}

	/**
	 * Returns whether the counter contains the given key. Note that this is the
	 * way to distinguish keys which are in the counter with count zero, and
	 * those which are not in the counter (and will therefore return the
	 * default count from getCount()).
	 *
	 * @param key
	 *            the key to look up
	 * @return whether the counter contains the key
	 */
	public boolean containsKey(final E key) {
		return entries.containsKey(key);
	}

	/**
	 * Get the count of the element, or the default count (zero unless changed)
	 * if the element is not in the counter.
	 *
	 * @param key
	 *            the key to look up
	 * @return the stored count, or {@link #getDefaultCount()} when absent
	 */
	public double getCount(final E key) {
		final Double value = entries.get(key);
		return value == null ? defaultCount : value.doubleValue();
	}

	/**
	 * I know, I know, this should be wrapped in a Distribution class, but it's
	 * such a common use...why not. Returns the MLE prob. Assumes all the counts
	 * are >= 0.0 and totalCount > 0.0. If the latter is false, return 0.0 (i.e.
	 * 0/0 == 0)
	 *
	 * @author Aria
	 * @param key
	 *            the key whose probability is wanted
	 * @return MLE prob of the key
	 * @throws RuntimeException
	 *             if the total count is negative
	 */
	public double getProbability(final E key) {
		final double count = getCount(key);
		final double total = totalCount();
		if (total < 0.0) {
			throw new RuntimeException("Can't call getProbability() with totalCount < 0.0");
		}
		return total > 0.0 ? count / total : 0.0;
	}

	/**
	 * Destructively normalize this Counter in place by dividing every count by
	 * the current total. No check is made on the total: if it is zero the
	 * resulting counts are NaN or infinite.
	 */
	public void normalize() {
		final double totalCount = totalCount();
		// Only values of existing keys are replaced, so iterating the live key
		// set while calling setCount is not a structural modification.
		for (final E key : keySet()) {
			setCount(key, getCount(key) / totalCount);
		}
		dirty = true;
	}

	/**
	 * Set the count for the given key, clobbering any previous count.
	 *
	 * @param key
	 *            the key to set
	 * @param count
	 *            the new count
	 */
	public void setCount(final E key, final double count) {
		entries.put(key, count);
		dirty = true;
	}

	/**
	 * Set the count for the given key. When {@code keepHigher} is true and the
	 * key is already present, the stored count is only replaced if
	 * {@code count} is larger; otherwise the count is overwritten.
	 *
	 * @param key
	 *            the key to set
	 * @param count
	 *            the candidate count
	 * @param keepHigher
	 *            whether an existing larger-or-equal count should be kept
	 */
	public void put(final E key, final double count, final boolean keepHigher) {
		if (keepHigher && entries.containsKey(key)) {
			final double oldCount = entries.get(key);
			if (count > oldCount) {
				entries.put(key, count);
			}
		} else {
			entries.put(key, count);
		}
		dirty = true;
	}

	/**
	 * Returns a key sampled with probability proportional to its count, using
	 * the supplied random source. Counts are assumed to be non-negative; this
	 * is not verified up front.
	 *
	 * @param rand
	 *            source of randomness
	 * @return the sampled key
	 * @throws RuntimeException
	 *             if totalCount() <= 0.0
	 * @throws IllegalStateException
	 *             if the cumulative fractions never exceed the random draw
	 *             (possible with negative counts or floating-point rounding)
	 *
	 * @author aria42
	 */
	public E sample(final Random rand) {
		final double total = totalCount();
		if (total <= 0.0) {
			throw new RuntimeException(String.format("Attempting to sample() with totalCount() %.3f\n", total));
		}
		double sum = 0.0;
		final double r = rand.nextDouble();
		for (final Map.Entry<E, Double> entry : entries.entrySet()) {
			final double count = entry.getValue();
			final double frac = count / total;
			sum += frac;
			if (r < sum) {
				return entry.getKey();
			}
		}
		throw new IllegalStateException("Shoudl've have returned a sample by now....");
	}

	/**
	 * Same as {@link #sample(Random)} with a freshly created {@link Random}.
	 *
	 * @return the sampled key
	 *
	 * @author aria42
	 */
	public E sample() {
		return sample(new Random());
	}

	/**
	 * Removes the key from the counter. The count is first set to zero through
	 * {@link #setCount(Object, double)} and the key is then dropped through
	 * {@link #removeKeyFromEntries(Object)}.
	 *
	 * @param key
	 *            the key to remove
	 */
	public void removeKey(final E key) {
		setCount(key, 0.0);
		dirty = true;
		removeKeyFromEntries(key);
	}

	/**
	 * Drops the key from the backing map.
	 *
	 * @param key
	 *            the key to drop
	 */
	protected void removeKeyFromEntries(final E key) {
		entries.remove(key);
	}

	/**
	 * Set's the key's count to the maximum of the current count and val. Always
	 * sets to val if key is not yet present.
	 *
	 * @param key
	 *            the key to update
	 * @param val
	 *            the candidate count
	 */
	public void setMaxCount(final E key, final double val) {
		final Double value = entries.get(key);
		if (value == null || val > value.doubleValue()) {
			setCount(key, val);
			dirty = true;
		}
	}

	/**
	 * Set's the key's count to the minimum of the current count and val. Always
	 * sets to val if key is not yet present.
	 *
	 * @param key
	 *            the key to update
	 * @param val
	 *            the candidate count
	 */
	public void setMinCount(final E key, final double val) {
		final Double value = entries.get(key);
		if (value == null || val < value.doubleValue()) {
			setCount(key, val);
			dirty = true;
		}
	}

	/**
	 * Increment a key's count by the given amount. An absent key starts from
	 * the default count.
	 *
	 * @param key
	 *            the key to update
	 * @param increment
	 *            the amount to add
	 * @return the new count of the key
	 */
	public double incrementCount(final E key, final double increment) {
		final double newVal = getCount(key) + increment;
		setCount(key, newVal);
		dirty = true;
		return newVal;
	}

	/**
	 * Increment each element in a given collection by a given amount.
	 * Elements occurring several times are incremented once per occurrence.
	 */
	public void incrementAll(final Collection<? extends E> collection, final double count) {
		for (final E key : collection) {
			incrementCount(key, count);
		}
		dirty = true;
	}

	/**
	 * Adds every count of the given counter to this counter.
	 *
	 * @param counter
	 *            the counts to add
	 */
	public <T extends E> void incrementAll(final Counter<T> counter) {
		incrementAll(counter, 1.0);
	}

	/**
	 * Adds every count of the given counter, multiplied by {@code scale}, to
	 * this counter.
	 *
	 * @param counter
	 *            the counts to add
	 * @param scale
	 *            factor applied to each added count
	 */
	public <T extends E> void incrementAll(final Counter<T> counter, final double scale) {
		for (final Entry<T, Double> entry : counter.entrySet()) {
			incrementCount(entry.getKey(), scale * entry.getValue());
		}
		dirty = true;
	}

	/**
	 * Finds the total of all counts in the counter. The total is cached: the
	 * entries are only re-summed when the counter has been marked dirty since
	 * the previous call.
	 *
	 * @return the counter's total
	 */
	public double totalCount() {
		if (!dirty) {
			return cacheTotal;
		}
		double total = 0.0;
		for (final Map.Entry<E, Double> entry : entries.entrySet()) {
			total += entry.getValue();
		}
		cacheTotal = total;
		dirty = false;
		return total;
	}

	/**
	 * @return a new list of the live entries, ordered by increasing count
	 */
	public Collection<Entry<E, Double>> getEntriesSortedByIncreasingCount() {
		final List<Entry<E, Double>> sorted = new ArrayList<Entry<E, Double>>(entrySet());
		Collections.sort(sorted, new EntryValueComparator(false));
		return sorted;
	}

	/**
	 * @return a new list of the live entries, ordered by decreasing count
	 */
	public Collection<Entry<E, Double>> getEntriesSortedByDecreasingCount() {
		final List<Entry<E, Double>> sorted = new ArrayList<Entry<E, Double>>(entrySet());
		Collections.sort(sorted, new EntryValueComparator(true));
		return sorted;
	}

	/**
	 * Finds the key with maximum count. This is a linear operation, and ties
	 * are broken arbitrarily.
	 *
	 * @return a key with maximum count, or null if the counter is empty
	 */
	public E argMax() {
		double maxCount = Double.NEGATIVE_INFINITY;
		E maxKey = null;
		for (final Map.Entry<E, Double> entry : entries.entrySet()) {
			if (entry.getValue() > maxCount || maxKey == null) {
				maxKey = entry.getKey();
				maxCount = entry.getValue();
			}
		}
		return maxKey;
	}

	/**
	 * @return the smallest count, or positive infinity if the counter is empty
	 */
	public double min() {
		return maxMinHelp(false);
	}

	/**
	 * @return the largest count, or negative infinity if the counter is empty
	 */
	public double max() {
		return maxMinHelp(true);
	}

	/**
	 * Linear scan for the extreme count.
	 *
	 * @param max
	 *            true to look for the maximum, false for the minimum
	 * @return the extreme count, or the matching infinity when empty
	 */
	private double maxMinHelp(final boolean max) {
		double maxCount = max ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
		for (final Map.Entry<E, Double> entry : entries.entrySet()) {
			if ((max && entry.getValue() > maxCount) || (!max && entry.getValue() < maxCount)) {
				maxCount = entry.getValue();
			}
		}
		return maxCount;
	}

	/**
	 * Returns a string representation with the entries ordered by key (see
	 * {@link #toStringSortedByKeys()}).
	 *
	 * @return string representation
	 */
	@Override
	public String toString() {
		return toStringSortedByKeys();
	}

	/**
	 * Renders the counter as {@code [key : count, ...]} with keys in their
	 * natural order and counts formatted with at most five fraction digits.
	 * Keys are sorted through a {@link TreeSet}, so they must be mutually
	 * comparable.
	 *
	 * @return string representation
	 */
	public String toStringSortedByKeys() {
		final StringBuilder sb = new StringBuilder("[");
		final NumberFormat f = NumberFormat.getInstance();
		f.setMaximumFractionDigits(5);
		int numKeysPrinted = 0;
		for (final E element : new TreeSet<E>(keySet())) {
			sb.append(element.toString());
			sb.append(" : ");
			sb.append(f.format(getCount(element)));
			if (numKeysPrinted < size() - 1) {
				sb.append(", ");
			}
			numKeysPrinted++;
		}
		// The sorted set can hold fewer elements than the counter when distinct
		// keys compare as equal; flag the omitted ones.
		if (numKeysPrinted < size()) {
			sb.append("...");
		}
		sb.append("]");
		return sb.toString();
	}

	/** Creates an empty counter. */
	public Counter() {
		entries = new HashMap<E, Double>();
	}

	/**
	 * Creates a counter holding a copy of the given counter's counts.
	 *
	 * @param counter
	 *            the counts to copy
	 */
	public Counter(final Counter<? extends E> counter) {
		this();
		incrementAll(counter);
	}

	/**
	 * Creates a counter in which every element of the collection is counted
	 * once per occurrence.
	 *
	 * @param collection
	 *            the elements to count
	 */
	public Counter(final Collection<? extends E> collection) {
		this();
		incrementAll(collection, 1.0);
	}

	/**
	 * Removes every key whose count is strictly below the cutoff.
	 *
	 * @param cutoff
	 *            smallest count that is kept
	 */
	public void pruneKeysBelowThreshold(final double cutoff) {
		final Iterator<Entry<E, Double>> it = entries.entrySet().iterator();
		while (it.hasNext()) {
			if (it.next().getValue() < cutoff) {
				it.remove();
			}
		}
		dirty = true;
	}

	/**
	 * Same view as {@link #entrySet()}.
	 *
	 * @return the live entry set of the backing map
	 */
	public Set<Map.Entry<E, Double>> getEntrySet() {
		return entries.entrySet();
	}

	/**
	 * Compares the two counters by count: they are equal when every key of
	 * either counter has the same {@link #getCount(Object)} in both. A key
	 * that is absent on one side therefore matches a key stored with that
	 * side's default count on the other.
	 *
	 * @param counter
	 *            the counter to compare with
	 * @return whether all counts agree
	 */
	public boolean isEqualTo(final Counter<E> counter) {
		// Both key sets must be inspected: checking only the larger one misses
		// differing keys when the sizes are equal.
		for (final E key : keySet()) {
			if (counter.getCount(key) != getCount(key)) {
				return false;
			}
		}
		for (final E key : counter.keySet()) {
			if (counter.getCount(key) != getCount(key)) {
				return false;
			}
		}
		return true;
	}

	/** Small demonstration of the counter API. */
	public static void main(final String[] args) {
		final Counter<String> counter = new Counter<String>();
		System.out.println(counter);
		counter.incrementCount("planets", 7);
		System.out.println(counter);
		counter.incrementCount("planets", 1);
		System.out.println(counter);
		counter.setCount("suns", 1);
		System.out.println(counter);
		counter.setCount("aliens", 0);
		System.out.println(counter);
		System.out.println(counter.toString());
		System.out.println("Total: " + counter.totalCount());
	}

	/** Removes all entries. */
	public void clear() {
		entries.clear();
		dirty = true;
	}

	/**
	 * Sets all counts to the given value, but does not remove any keys
	 */
	public void setAllCounts(final double val) {
		for (final E e : keySet()) {
			setCount(e, val);
		}
	}

	/**
	 * Sum over this counter's keys of the product of both counters' counts.
	 *
	 * @param other
	 *            the counter to multiply with
	 * @return the dot product
	 */
	public double dotProduct(final Counter<E> other) {
		double sum = 0.0;
		for (final Map.Entry<E, Double> entry : getEntrySet()) {
			final double otherCount = other.getCount(entry.getKey());
			if (otherCount == 0.0) {
				continue;
			}
			final double value = entry.getValue();
			if (value == 0.0) {
				continue;
			}
			sum += value * otherCount;
		}
		return sum;
	}

	/**
	 * Multiplies every count by {@code c}, in place.
	 *
	 * @param c
	 *            the scaling factor
	 */
	public void scale(final double c) {
		for (final Map.Entry<E, Double> entry : getEntrySet()) {
			entry.setValue(entry.getValue() * c);
		}
		// Values were written through the entry view, bypassing setCount, so
		// the cached total has to be invalidated explicitly.
		dirty = true;
	}

	/**
	 * @param c
	 *            the scaling factor
	 * @return a new counter holding this counter's counts multiplied by
	 *         {@code c}
	 */
	public Counter<E> scaledClone(final double c) {
		final Counter<E> newCounter = new Counter<E>();
		for (final Map.Entry<E, Double> entry : getEntrySet()) {
			newCounter.setCount(entry.getKey(), entry.getValue() * c);
		}
		return newCounter;
	}

	/**
	 * @param counter
	 *            the counts to subtract
	 * @return a new counter holding this counter's counts minus the given
	 *         counter's counts
	 */
	public Counter<E> difference(final Counter<E> counter) {
		final Counter<E> clone = new Counter<E>(this);
		for (final E key : counter.keySet()) {
			final double count = counter.getCount(key);
			clone.incrementCount(key, -1 * count);
		}
		return clone;
	}

	/**
	 * @return a new counter holding the natural logarithm of each count
	 */
	public Counter<E> toLogSpace() {
		final Counter<E> newCounter = new Counter<E>(this);
		for (final E key : newCounter.keySet()) {
			newCounter.setCount(key, Math.log(getCount(key)));
		}
		return newCounter;
	}

	/**
	 * @param other
	 *            the counter to compare with
	 * @param tol
	 *            largest absolute difference that still counts as equal
	 * @return whether, for every key of either counter, the two counts differ
	 *         by at most {@code tol}
	 */
	public boolean approxEquals(final Counter<E> other, final double tol) {
		for (final E key : keySet()) {
			if (Math.abs(getCount(key) - other.getCount(key)) > tol) {
				return false;
			}
		}
		for (final E key : other.keySet()) {
			if (Math.abs(getCount(key) - other.getCount(key)) > tol) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Overrides the cache state. Pass true after changing counts through one
	 * of the live views so that the next {@link #totalCount()} recomputes the
	 * sum.
	 *
	 * @param dirty
	 *            whether the cached total should be considered stale
	 */
	public void setDirty(final boolean dirty) {
		this.dirty = dirty;
	}

	/**
	 * @return an iterable over the counts; its iterators support removal
	 */
	public Iterable<Double> values() {
		return new Iterable<Double>()
		{
			@Override
			public Iterator<Double> iterator() {
				return new Iterator<Double>()
				{
					private final Iterator<Entry<E, Double>> entryIterator = entrySet().iterator();

					@Override
					public boolean hasNext() {
						return entryIterator.hasNext();
					}

					@Override
					public Double next() {
						return entryIterator.next().getValue();
					}

					@Override
					public void remove() {
						entryIterator.remove();
						// An entry disappeared, so the cached total is stale.
						Counter.this.dirty = true;
					}
				};
			}
		};
	}

	/**
	 * Removes every key contained in the given set.
	 *
	 * @param toRemove
	 *            the keys to remove
	 */
	public void prune(final Set<E> toRemove) {
		for (final E e : toRemove) {
			removeKey(e);
		}
	}

	/**
	 * Removes every key that is not contained in the given set.
	 *
	 * @param toKeep
	 *            the keys to retain
	 */
	public void pruneExcept(final Set<E> toKeep) {
		// Collect first: removing while iterating the key set would be a
		// structural modification of the map being iterated.
		final List<E> toRemove = new ArrayList<E>();
		for (final E key : entries.keySet()) {
			if (!toKeep.contains(key)) {
				toRemove.add(key);
			}
		}
		for (final E e : toRemove) {
			removeKey(e);
		}
	}

	/**
	 * @param counts
	 *            the source counter
	 * @return a new counter holding the absolute value of each count
	 */
	public static <L> Counter<L> absCounts(final Counter<L> counts) {
		final Counter<L> res = new Counter<L>();
		for (final Map.Entry<L, Double> entry : counts.entrySet()) {
			res.incrementCount(entry.getKey(), Math.abs(entry.getValue()));
		}
		return res;
	}

	/**
	 * Compares entries by value, in ascending or descending order.
	 */
	public class EntryValueComparator implements Comparator<Entry<E, Double>>
	{
		/**
		 * @param descending
		 *            true to order larger counts first
		 */
		public EntryValueComparator(final boolean descending) {
			super();
			this.descending = descending;
		}

		private final boolean descending;

		@Override
		public int compare(final Entry<E, Double> e1, final Entry<E, Double> e2) {
			return descending ? Double.compare(e2.getValue(), e1.getValue()) : Double.compare(e1.getValue(), e2.getValue());
		}
	}

	/**
	 * Sets the count of every key currently in the counter to {@code d}.
	 *
	 * @param d
	 *            the new count for all keys
	 */
	public void putAll(final double d) {
		for (final Entry<E, Double> entry : entries.entrySet()) {
			setCount(entry.getKey(), d);
		}
		dirty = true;
	}
}