package edu.berkeley.nlp.util; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.Set; /** * A version of Counter that does not create Double objects. * * This should most certainly be rewritten. * * John */ public class FastCounter<E> implements Serializable { private static final long serialVersionUID = 1L; TDoubleMap<E> entries = new TDoubleMap<E>(); /** * The elements in the counter. * * @return set of keys */ public Set<E> keySet() { return entries.keySet(); } public void multAll(double dValue) { entries.multAll(dValue); } /** * The number of entries in the counter (not the total count -- use totalCount() instead). */ public int size() { return entries.size(); } /** * True if there are no entries in the counter (false does not mean totalCount > 0) */ public boolean isEmpty() { return size() == 0; } /** * Returns whether the counter contains the given key. Note that this is the * way to distinguish keys which are in the counter with count zero, and those * which are not in the counter (and will therefore return count zero from * getCount(). * * @param key * @return whether the counter contains the key */ public boolean containsKey(E key) { return entries.containsKey(key); } /** * Get the count of the element, or zero if the element is not in the * counter. * * @param key * @return */ public double getCount(E key) { return entries.get(key, 0.0); } /** * Destructively normalize this Counter in place. */ public void normalize() { entries.multAll(1.0 / entries.sum()); } /** * Set the count for the given key, clobbering any previous count. * * @param key * @param count */ public void setCount(E key, double count) { entries.put(key, count); } /** * Increment a key's count by the given amount. * * @param key * @param increment */ public void incrementCount(E key, double increment) { entries.incr(key, increment); } /** * Increment each element in a given collection by a given amount. */ public void incrementAll(Collection<? extends E> collection, double count) { for (E key : collection) { incrementCount(key, count); } } public <T extends E> void incrementAll(Counter<T> counter) { for (T key : counter.keySet()) { double count = counter.getCount(key); incrementCount(key, count); } } public <T extends E> void incrementAll(FastCounter<T> counter) { for (T key : counter.keySet()) { double count = counter.getCount(key); incrementCount(key, count); } } /** * Finds the total of all counts in the counter. This implementation iterates * through the entire counter every time this method is called. * * @return the counter's total */ public double totalCount() { return entries.sum(); } public List<E> getSortedKeys() { PriorityQueue<E> pq = this.asPriorityQueue(); List<E> keys = new ArrayList<E>(); while (pq.hasNext()) { keys.add(pq.next()); } return keys; } /** * Finds the key with maximum count. This is a linear operation, and ties are broken arbitrarily. * * @return a key with minumum count */ public E argMax() { return entries.argmax(); } public double min() { return maxMinHelp(false); } public double max() { return maxMinHelp(true); } private double maxMinHelp(boolean max) { double maxCount = max ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY; for (E key : entries.keySet()) { double val = entries.getSure(key); if ((max && val > maxCount) || (!max && val < maxCount)) { maxCount = val; } } return maxCount; } /** * Returns a string representation with the keys ordered by decreasing * counts. * * @return string representation */ @Override public String toString() { return toString(keySet().size()); } /** * Returns a string representation which includes no more than the * maxKeysToPrint elements with largest counts. * * @param maxKeysToPrint * @return partial string representation */ public String toString(int maxKeysToPrint) { return asPriorityQueue().toString(maxKeysToPrint, false); } /** * Builds a priority queue whose elements are the counter's elements, and * whose priorities are those elements' counts in the counter. */ public PriorityQueue<E> asPriorityQueue() { PriorityQueue<E> pq = new PriorityQueue<E>(entries.size()); for (E key : entries.keySet()) { pq.add(key, entries.getSure(key)); } return pq; } /** * Warning: all priorities are the negative of their counts in the counter here * @return */ public PriorityQueue<E> asMinPriorityQueue() { PriorityQueue<E> pq = new PriorityQueue<E>(entries.size()); for (E key : entries.keySet()) { pq.add(key, -1.0 * entries.getSure(key)); } return pq; } public void pruneKeysBelowThreshold(double cutoff) { Iterator<E> it = entries.keySet().iterator(); Set<E> remaining = new HashSet<E>(); while (it.hasNext()) { E key = it.next(); double val = entries.getSure(key); if (val >= cutoff) remaining.add(key); } entries = entries.restrict(remaining); } public void clear() { entries.gut(); } public void keepTopNKeys(int keepN) { keepKeysHelper(keepN, true); } public void keepBottomNKeys(int keepN) { keepKeysHelper(keepN, false); } private void keepKeysHelper(int keepN, boolean top) { Counter<E> tmp = new Counter<E>(); int n = 0; for (E e : Iterators.able(top ? asPriorityQueue() : asMinPriorityQueue())) { if (n <= keepN) tmp.setCount(e, getCount(e)); n++; } clear(); incrementAll(tmp); } /** * Sets all counts to the given value, but does not remove any keys */ public void setAllCounts(double val) { for (E e : keySet()) { setCount(e, val); } } public void switchToSortedList() { entries.switchToSortedList(); } public void switchToHashTable() { entries.switchToHashTable(); } public static void main(String[] args) { FastCounter<String> counter = new FastCounter<String>(); System.out.println(counter); counter.incrementCount("planets", 7); System.out.println(counter); counter.incrementCount("planets", 1); System.out.println(counter); counter.setCount("suns", 1); System.out.println(counter); counter.setCount("aliens", 0.5); System.out.println(counter); System.out.println(counter.toString(2)); System.out.println("Total: " + counter.totalCount()); counter.pruneKeysBelowThreshold(0.6); System.out.println(counter); System.out.println(counter.totalCount()); System.out.println("Waiting for profiler..."); try { Thread.sleep(5000); } catch (InterruptedException e) { throw new RuntimeException(e); } System.out.println("Done."); // Speed and memory tests FastCounter<Integer> fast = new FastCounter<Integer>(); Counter<Integer> baseline = new Counter<Integer>(); StopWatch watch = new StopWatch(); Random r = new Random(); int size = 50000000; watch.start(); for (int i = 0; i < size; i++) { fast.incrementCount(r.nextInt(size / 10), 1); } watch.stop(); System.out.println("Fast: " + watch.toString()); try { Thread.sleep(5000); System.out.println("Waiting for profiler..."); } catch (InterruptedException e) { throw new RuntimeException(e); } watch.reset(); watch.start(); fast.entries.switchToSortedList(); watch.stop(); System.out.println("Switching: " + watch.toString()); try { Thread.sleep(5000); System.out.println("Waiting for profiler..."); } catch (InterruptedException e) { throw new RuntimeException(e); } watch.reset(); watch.start(); for (int i = 0; i < size; i++) { baseline.incrementCount(r.nextInt(size / 10), 1); } watch.stop(); System.out.println("Baseline: " + watch.toString()); } }