// Stanford JavaNLP support classes // Copyright (c) 2004-2008 The Board of Trustees of // The Leland Stanford Junior University. All Rights Reserved. // // This program is free software; you can redistribute it and/or // modify it under the terms of the GNU General Public License // as published by the Free Software Foundation; either version 2 // of the License, or (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program; if not, write to the Free Software // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. // // For more information, bug reports, fixes, contact: // Christopher Manning // Dept of Computer Science, Gates 1A // Stanford CA 94305-9010 // USA // java-nlp-support@lists.stanford.edu // http://nlp.stanford.edu/software/ package edu.stanford.nlp.stats; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileWriter; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.io.PrintStream; import java.io.PrintWriter; import java.lang.reflect.Constructor; import java.text.NumberFormat; import java.util.AbstractCollection; import java.util.AbstractMap; import java.util.AbstractSet; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.Map.Entry; import java.util.function.Function; 
import java.util.regex.Pattern; import edu.stanford.nlp.io.IOUtils; import edu.stanford.nlp.math.ArrayMath; import edu.stanford.nlp.math.SloppyMath; import edu.stanford.nlp.util.*; import edu.stanford.nlp.util.logging.Redwood; import edu.stanford.nlp.util.logging.PrettyLogger; import edu.stanford.nlp.util.logging.Redwood.RedwoodChannels; /** * Static methods for operating on a {@link Counter}. * <p> * All methods that change their arguments change the <i>first</i> argument * (only), and have "InPlace" in their name. This class also provides access to * Comparators that can be used to sort the keys or entries of this Counter by * the counts, in either ascending or descending order. * * @author Galen Andrew (galand@cs.stanford.edu) * @author Jeff Michels (jmichels@stanford.edu) * @author dramage * @author daniel cer (http://dmcer.net) * @author Christopher Manning * @author stefank (Optimized dot product) */ public class Counters { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(Counters.class); private static final double LOG_E_2 = Math.log(2.0); private Counters() {} // only static methods // // Log arithmetic operations // /** * Returns ArrayMath.logSum of the values in this counter. * * @param c Argument counter (which is not modified) * @return ArrayMath.logSum of the values in this counter. */ public static <E> double logSum(Counter<E> c) { return ArrayMath.logSum(ArrayMath.unbox(c.values())); } /** * Transform log space values into a probability distribution in place. On the * assumption that the values in the Counter are in log space, this method * calculates their sum, and then subtracts the log of their sum from each * element. That is, if a counter has keys c1, c2, c3 with values v1, v2, v3, * the value of c1 becomes v1 - log(e^v1 + e^v2 + e^v3). After this, e^v1 + * e^v2 + e^v3 = 1.0, so Counters.logSum(c) = 0.0 (approximately). 
* * @param c The Counter to log normalize in place */ @SuppressWarnings( { "UnnecessaryUnboxing" }) public static <E> void logNormalizeInPlace(Counter<E> c) { double logsum = logSum(c); // for (E key : c.keySet()) { // c.incrementCount(key, -logsum); // } // This should be faster for (Map.Entry<E, Double> e : c.entrySet()) { e.setValue(e.getValue().doubleValue() - logsum); } } // // Query operations // /** * Returns the value of the maximum entry in this counter. This is also the * L_infinity norm. An empty counter is given a max value of * Double.NEGATIVE_INFINITY. * * @param c The Counter to find the max of * @return The maximum value of the Counter */ public static <E> double max(Counter<E> c) { return max(c, Double.NEGATIVE_INFINITY); // note[gabor]: Should the default actually be 0 rather than negative_infinity? } /** * Returns the value of the maximum entry in this counter. This is also the * L_infinity norm. An empty counter is given a max value of * Double.NEGATIVE_INFINITY. * * @param c The Counter to find the max of * @param valueIfEmpty The value to return if this counter is empty (i.e., the maximum is not well defined. * @return The maximum value of the Counter */ public static <E> double max(Counter<E> c, double valueIfEmpty) { if (c.size() == 0) { return valueIfEmpty; } else { double max = Double.NEGATIVE_INFINITY; for (double v : c.values()) { max = Math.max(max, v); } return max; } } /** * Takes in a Collection of something and makes a counter, incrementing once * for each object in the collection. * * @param c The Collection to turn into a counter * @return The counter made out of the collection */ public static <E> Counter<E> asCounter(Collection<E> c) { Counter<E> count = new ClassicCounter<>(); for (E elem : c) { count.incrementCount(elem); } return count; } /** * Returns the value of the smallest entry in this counter. 
* * @param c The Counter (not modified) * @return The minimum value in the Counter */ public static <E> double min(Counter<E> c) { double min = Double.POSITIVE_INFINITY; for (double v : c.values()) { min = Math.min(min, v); } return min; } /** * Finds and returns the key in the Counter with the largest count. Returning * null if count is empty. * * @param c The Counter * @return The key in the Counter with the largest count. */ public static <E> E argmax(Counter<E> c) { return argmax(c, (x, y) -> 0, null); } /** * Finds and returns the key in this Counter with the smallest count. * * @param c The Counter * @return The key in the Counter with the smallest count. */ public static <E> E argmin(Counter<E> c) { double min = Double.POSITIVE_INFINITY; E argmin = null; for (E key : c.keySet()) { double count = c.getCount(key); if (argmin == null || count < min) { // || (count == min && tieBreaker.compare(key, argmin) < 0) min = count; argmin = key; } } return argmin; } /** * Finds and returns the key in the Counter with the largest count. Returning * null if count is empty. * * @param c The Counter * @param tieBreaker the tie breaker for when elements have the same value. * @return The key in the Counter with the largest count. */ public static <E> E argmax(Counter<E> c, Comparator<E> tieBreaker) { return argmax(c, tieBreaker, (E) null); } /** * Finds and returns the key in the Counter with the largest count. Returning * null if count is empty. * * @param c The Counter * @param tieBreaker the tie breaker for when elements have the same value. * @param defaultIfEmpty The value to return if the counter is empty. * @return The key in the Counter with the largest count. 
*/ public static <E> E argmax(Counter<E> c, Comparator<E> tieBreaker, E defaultIfEmpty) { if (Thread.interrupted()) { // A good place to check for interrupts -- called from many annotators throw new RuntimeInterruptedException(); } if (c.size() == 0) { return defaultIfEmpty; } double max = Double.NEGATIVE_INFINITY; E argmax = null; for (E key : c.keySet()) { double count = c.getCount(key); if (argmax == null || count > max || (count == max && tieBreaker.compare(key, argmax) < 0)) { max = count; argmax = key; } } return argmax; } /** * Finds and returns the key in this Counter with the smallest count. * * @param c The Counter * @return The key in the Counter with the smallest count. */ public static <E> E argmin(Counter<E> c, Comparator<E> tieBreaker) { double min = Double.POSITIVE_INFINITY; E argmin = null; for (E key : c.keySet()) { double count = c.getCount(key); if (argmin == null || count < min || (count == min && tieBreaker.compare(key, argmin) < 0)) { min = count; argmin = key; } } return argmin; } /** * Returns the mean of all the counts (totalCount/size). * * @param c The Counter to find the mean of. * @return The mean of all the counts (totalCount/size). */ public static <E> double mean(Counter<E> c) { return c.totalCount() / c.size(); } public static <E> double standardDeviation(Counter<E> c) { double std = 0; double mean = c.totalCount() / c.size(); for (Map.Entry<E, Double> en : c.entrySet()) { std += (en.getValue() - mean) * (en.getValue() - mean); } return Math.sqrt(std / c.size()); } // // In-place arithmetic // /** * Sets each value of target to be target[k]+scale*arg[k] for all keys k in * target. 
* * @param target A Counter that is modified * @param arg The Counter whose contents are added to target * @param scale How the arg Counter is scaled before being added */ // TODO: Rewrite to use arg.entrySet() public static <E> void addInPlace(Counter<E> target, Counter<E> arg, double scale) { for (E key : arg.keySet()) { target.incrementCount(key, scale * arg.getCount(key)); } } /** * Sets each value of target to be target[k]+arg[k] for all keys k in arg. */ public static <E> void addInPlace(Counter<E> target, Counter<E> arg) { for (Map.Entry<E, Double> entry : arg.entrySet()) { double count = entry.getValue(); if (count != 0) { target.incrementCount(entry.getKey(), count); } } } /** * Sets each value of double[] target to be * target[idx.indexOf(k)]+a.getCount(k) for all keys k in arg */ public static <E> void addInPlace(double[] target, Counter<E> arg, Index<E> idx) { for (Map.Entry<E, Double> entry : arg.entrySet()) { target[idx.indexOf(entry.getKey())] += entry.getValue(); } } /** * For all keys (u,v) in arg1 and arg2, sets return[u,v] to be summation of both. * @param <T1> * @param <T2> */ public static <T1, T2> TwoDimensionalCounter<T1, T2> add(TwoDimensionalCounter<T1, T2> arg1, TwoDimensionalCounter<T1, T2> arg2) { TwoDimensionalCounter<T1, T2> add = new TwoDimensionalCounter<>(); Counters.addInPlace(add , arg1); Counters.addInPlace(add , arg2); return add; } /** * For all keys (u,v) in arg, sets target[u,v] to be target[u,v] + scale * * arg[u,v]. * * @param <T1> * @param <T2> */ public static <T1, T2> void addInPlace(TwoDimensionalCounter<T1, T2> target, TwoDimensionalCounter<T1, T2> arg, double scale) { for (T1 outer : arg.firstKeySet()) for (T2 inner : arg.secondKeySet()) { target.incrementCount(outer, inner, scale * arg.getCount(outer, inner)); } } /** * For all keys (u,v) in arg, sets target[u,v] to be target[u,v] + arg[u,v]. 
* * @param <T1> * @param <T2> */ public static <T1, T2> void addInPlace(TwoDimensionalCounter<T1, T2> target, TwoDimensionalCounter<T1, T2> arg) { for (T1 outer : arg.firstKeySet()) for (T2 inner : arg.secondKeySet()) { target.incrementCount(outer, inner, arg.getCount(outer, inner)); } } /** * Sets each value of target to be target[k]+ * value*(num-of-times-it-occurs-in-collection) if the key is present in the arg * collection. */ public static <E> void addInPlace(Counter<E> target, Collection<E> arg, double value) { for (E key : arg) { target.incrementCount(key, value); } } /** * For all keys (u,v) in target, sets target[u,v] to be target[u,v] + value * * @param <T1> * @param <T2> */ public static <T1, T2> void addInPlace(TwoDimensionalCounter<T1, T2> target, double value) { for (T1 outer : target.firstKeySet()){ addInPlace(target.getCounter(outer), value); } } /** * Sets each value of target to be target[k]+ * num-of-times-it-occurs-in-collection if the key is present in the arg * collection. */ public static <E> void addInPlace(Counter<E> target, Collection<E> arg) { for (E key : arg) { target.incrementCount(key, 1); } } /** * Increments all keys in a Counter by a specific value. */ public static <E> void addInPlace(Counter<E> target, double value) { for (E key : target.keySet()) { target.incrementCount(key, value); } } /** * Sets each value of target to be target[k]-arg[k] for all keys k in target. */ public static <E> void subtractInPlace(Counter<E> target, Counter<E> arg) { for (E key : arg.keySet()) { target.decrementCount(key, arg.getCount(key)); } } /** * Sets each value of double[] target to be * target[idx.indexOf(k)]-a.getCount(k) for all keys k in arg */ public static <E> void subtractInPlace(double[] target, Counter<E> arg, Index<E> idx) { for (Map.Entry<E, Double> entry : arg.entrySet()) { target[idx.indexOf(entry.getKey())] -= entry.getValue(); } } /** * Divides every non-zero count in target by the corresponding value in the * denominator Counter. 
Beware that this can give NaN values for zero counts
   * in the denominator counter!
   */
  public static <E> void divideInPlace(Counter<E> target, Counter<E> denominator) {
    for (E key : target.keySet()) {
      double quotient = target.getCount(key) / denominator.getCount(key);
      target.setCount(key, quotient);
    }
  }

  /**
   * Multiplies every count in target by the corresponding value in the term
   * Counter.
   */
  public static <E> void dotProductInPlace(Counter<E> target, Counter<E> term) {
    for (E key : target.keySet()) {
      double product = target.getCount(key) * term.getCount(key);
      target.setCount(key, product);
    }
  }

  /**
   * Divides each value in target by the given divisor, in place.
   *
   * @param target The Counter whose values are each divided by the divisor
   * @param divisor The number by which each value in the Counter is divided
   * @return The target Counter is returned (for easier method chaining)
   */
  public static <E> Counter<E> divideInPlace(Counter<E> target, double divisor) {
    for (Entry<E, Double> entry : target.entrySet()) {
      double quotient = entry.getValue() / divisor;
      target.setCount(entry.getKey(), quotient);
    }
    return target;
  }

  /**
   * Multiplies each value in target by the given multiplier, in place.
   *
   * @param target The values in this Counter will be multiplied by the
   *          multiplier
   * @param multiplier The number by which to change each number in the Counter
   * @return The target Counter is returned (for easier method chaining)
   */
  public static <E> Counter<E> multiplyInPlace(Counter<E> target, double multiplier) {
    for (Entry<E, Double> entry : target.entrySet()) {
      double product = entry.getValue() * multiplier;
      target.setCount(entry.getKey(), product);
    }
    return target;
  }

  /**
   * Multiplies each value in target by the count of the key in mult, in place.
Returns non zero entries
   *
   * @param target The counter
   * @param mult The counter you want to multiply with target
   */
  public static <E> Counter<E> multiplyInPlace(Counter<E> target, Counter<E> mult) {
    for (Entry<E, Double> entry : target.entrySet()) {
      E key = entry.getKey();
      double product = entry.getValue() * mult.getCount(key);
      target.setCount(key, product);
    }
    // Entries whose product came out as zero are dropped from target.
    Counters.retainNonZeros(target);
    return target;
  }

  /**
   * Divides every value of the target counter, in place, by the counter's
   * total count, so the sum of the resulting values equals 1.
   *
   * @param <E> Type of elements in Counter
   */
  public static <E> void normalize(Counter<E> target) {
    double total = target.totalCount();
    divideInPlace(target, total);
  }

  /**
   * L1 normalize a counter. Return a counter that is a probability distribution,
   * so the sum of the resulting value equals 1.
   *
   * @param c The {@link Counter} to be L1 normalized. This counter is not
   *          modified.
   * @return A new L1-normalized Counter based on c.
   */
  public static <E, C extends Counter<E>> C asNormalizedCounter(C c) {
    double factor = 1.0 / c.totalCount();
    return scale(c, factor);
  }

  /**
   * Normalizes the target counter in-place, so the sum of the resulting values
   * equals 1.
* * @param <E> Type of elements in TwoDimensionalCounter * @param <F> Type of elements in TwoDimensionalCounter */ public static <E, F> void normalize(TwoDimensionalCounter<E, F> target) { Counters.divideInPlace(target, target.totalCount()); } public static <E> void logInPlace(Counter<E> target) { for (E key : target.keySet()) { target.setCount(key, Math.log(target.getCount(key))); } } // // Selection Operators // /** * Delete 'top' and 'bottom' number of elements from the top and bottom * respectively */ public static <E> List<E> deleteOutofRange(Counter<E> c, int top, int bottom) { List<E> purgedItems = new ArrayList<>(); int numToPurge = top + bottom; if (numToPurge <= 0) { return purgedItems; } List<E> l = Counters.toSortedList(c); for (int i = 0; i < top; i++) { E item = l.get(i); purgedItems.add(item); c.remove(item); } int size = c.size(); for (int i = c.size() - 1; i >= (size - bottom); i--) { E item = l.get(i); purgedItems.add(item); c.remove(item); } return purgedItems; } /** * Removes all entries from c except for the top {@code num}. */ public static <E> void retainTop(Counter<E> c, int num) { int numToPurge = c.size() - num; if (numToPurge <= 0) { return; } List<E> l = Counters.toSortedList(c, true); for (int i = 0; i < numToPurge; i++) { c.remove(l.get(i)); } } /** * Removes all entries from c except for the top {@code num}. */ public static <E extends Comparable<E>> void retainTopKeyComparable(Counter<E> c, int num) { int numToPurge = c.size() - num; if (numToPurge <= 0) { return; } List<E> l = Counters.toSortedListKeyComparable(c); Collections.reverse(l); for (int i = 0; i < numToPurge; i++) { c.remove(l.get(i)); } } /** * Removes all entries from c except for the bottom {@code num}. 
*/
  public static <E> List<E> retainBottom(Counter<E> c, int num) {
    int excess = c.size() - num;
    if (excess <= 0) {
      return Generics.newArrayList();
    }

    List<E> removed = new ArrayList<>();
    // toSortedList is descending, so the keys to drop come first.
    List<E> descending = Counters.toSortedList(c);
    for (int i = 0; i < excess; i++) {
      E key = descending.get(i);
      removed.add(key);
      c.remove(key);
    }
    return removed;
  }

  /**
   * Removes all entries with 0 count in the counter, returning the set of
   * removed entries.
   */
  public static <E> Set<E> retainNonZeros(Counter<E> counter) {
    // Collect first, remove afterwards, so the key set is not modified while
    // it is being iterated.
    Set<E> zeroKeys = Generics.newHashSet();
    for (E key : counter.keySet()) {
      if (counter.getCount(key) == 0.0) {
        zeroKeys.add(key);
      }
    }
    for (E key : zeroKeys) {
      counter.remove(key);
    }
    return zeroKeys;
  }

  /**
   * Removes all entries with counts below the given threshold, returning the
   * set of removed entries.
   *
   * @param counter The counter.
   * @param countThreshold
   *          The minimum count for an entry to be kept. Entries (strictly) less
   *          than this threshold are discarded.
   * @return The set of discarded entries.
   */
  public static <E> Set<E> retainAbove(Counter<E> counter, double countThreshold) {
    Set<E> tooSmall = Generics.newHashSet();
    for (E key : counter.keySet()) {
      double count = counter.getCount(key);
      if (count < countThreshold) {
        tooSmall.add(key);
      }
    }
    for (E key : tooSmall) {
      counter.remove(key);
    }
    return tooSmall;
  }

  /**
   * Removes all entries with counts below the given threshold, returning the
   * set of removed entries.
   *
   * @param counter The counter.
   * @param countThreshold
   *          The minimum count for an entry to be kept. Entries (strictly) less
   *          than this threshold are discarded.
   * @return The set of discarded entries.
*/
  public static <E1, E2> Set<Pair<E1, E2>> retainAbove(
      TwoDimensionalCounter<E1, E2> counter, double countThreshold) {
    Set<Pair<E1, E2>> tooSmall = new HashSet<>();
    for (Entry<E1, ClassicCounter<E2>> outer : counter.entrySet()) {
      E1 firstKey = outer.getKey();
      for (Entry<E2, Double> inner : outer.getValue().entrySet()) {
        E2 secondKey = inner.getKey();
        if (counter.getCount(firstKey, secondKey) < countThreshold) {
          tooSmall.add(new Pair<>(firstKey, secondKey));
        }
      }
    }
    // Removal happens after the scan so the entry sets are not modified
    // while they are being iterated.
    for (Pair<E1, E2> pair : tooSmall) {
      counter.remove(pair.first(), pair.second());
    }
    return tooSmall;
  }

  /**
   * Removes all entries with counts above the given threshold, returning the
   * removed entries.
   *
   * @param counter The counter.
   * @param countMaxThreshold
   *          The maximum count for an entry to be kept. Entries (strictly) more
   *          than this threshold are discarded.
   * @return A new counter holding the discarded keys with the counts they had.
   */
  public static <E> Counter<E> retainBelow(Counter<E> counter, double countMaxThreshold) {
    Counter<E> tooLarge = new ClassicCounter<>();
    for (E key : counter.keySet()) {
      double count = counter.getCount(key);
      if (count > countMaxThreshold) {
        tooLarge.setCount(key, count);
      }
    }
    for (E key : tooLarge.keySet()) {
      counter.remove(key);
    }
    return tooLarge;
  }

  /**
   * Removes all entries with keys that does not match one of the given patterns.
   * A key is kept when at least one pattern matches the entire key.
   *
   * @param counter The counter.
   * @param matchPatterns pattern for key to match
   * @return The set of discarded entries.
   */
  public static Set<String> retainMatchingKeys(Counter<String> counter, List<Pattern> matchPatterns) {
    Set<String> unmatched = Generics.newHashSet();
    for (String key : counter.keySet()) {
      boolean matched = false;
      Iterator<Pattern> patterns = matchPatterns.iterator();
      // Stop at the first pattern that accepts the key.
      while (!matched && patterns.hasNext()) {
        matched = patterns.next().matcher(key).matches();
      }
      if (!matched) {
        unmatched.add(key);
      }
    }
    for (String key : unmatched) {
      counter.remove(key);
    }
    return unmatched;
  }

  /**
   * Removes all entries with keys that does not match the given set of keys.
* * @param counter The counter * @param matchKeys Keys to match * @return The set of discarded entries. */ public static<E> Set<E> retainKeys(Counter<E> counter, Collection<E> matchKeys) { Set<E> removed = Generics.newHashSet(); for (E key : counter.keySet()) { boolean matched = matchKeys.contains(key); if (!matched) { removed.add(key); } } for (E key : removed) { counter.remove(key); } return removed; } /** * Removes all entries with keys in the given collection * * @param <E> * @param counter * @param removeKeysCollection */ public static <E> void removeKeys(Counter<E> counter, Collection<E> removeKeysCollection) { for (E key : removeKeysCollection) counter.remove(key); } /** * Removes all entries with keys (first key set) in the given collection * * @param <E> * @param counter * @param removeKeysCollection */ public static <E, F> void removeKeys(TwoDimensionalCounter<E, F> counter, Collection<E> removeKeysCollection) { for (E key : removeKeysCollection) counter.remove(key); } /** * Returns the set of keys whose counts are at or above the given threshold. * This set may have 0 elements but will not be null. * * @param c The Counter to examine * @param countThreshold * Items equal to or above this number are kept * @return A (non-null) Set of keys whose counts are at or above the given * threshold. */ public static <E> Set<E> keysAbove(Counter<E> c, double countThreshold) { Set<E> keys = Generics.newHashSet(); for (E key : c.keySet()) { if (c.getCount(key) >= countThreshold) { keys.add(key); } } return (keys); } /** * Returns the set of keys whose counts are at or below the given threshold. * This set may have 0 elements but will not be null. */ public static <E> Set<E> keysBelow(Counter<E> c, double countThreshold) { Set<E> keys = Generics.newHashSet(); for (E key : c.keySet()) { if (c.getCount(key) <= countThreshold) { keys.add(key); } } return (keys); } /** * Returns the set of keys that have exactly the given count. 
This set may * have 0 elements but will not be null. */ public static <E> Set<E> keysAt(Counter<E> c, double count) { Set<E> keys = Generics.newHashSet(); for (E key : c.keySet()) { if (c.getCount(key) == count) { keys.add(key); } } return (keys); } // // Transforms // /** * Returns the counter with keys modified according to function F. Eager * evaluation. If two keys are same after the transformation, one of the values is randomly chosen (depending on how the keyset is traversed) */ public static <T1, T2> Counter<T2> transform(Counter<T1> c, Function<T1, T2> f) { Counter<T2> c2 = new ClassicCounter<>(); for (T1 key : c.keySet()) { c2.setCount(f.apply(key), c.getCount(key)); } return c2; } /** * Returns the counter with keys modified according to function F. If two keys are same after the transformation, their values get added up. */ public static <T1, T2> Counter<T2> transformWithValuesAdd(Counter<T1> c, Function<T1, T2> f) { Counter<T2> c2 = new ClassicCounter<>(); for (T1 key : c.keySet()) { c2.incrementCount(f.apply(key), c.getCount(key)); } return c2; } // // Conversion to other types // /** * Returns a comparator backed by this counter: two objects are compared by * their associated values stored in the counter. This comparator returns keys * by ascending numeric value. Note that this ordering is not fixed, but * depends on the mutable values stored in the Counter. Doing this comparison * does not depend on the type of the key, since it uses the numeric value, * which is always Comparable. * * @param counter The Counter whose values are used for ordering the keys * @return A Comparator using this ordering */ public static <E> Comparator<E> toComparator(final Counter<E> counter) { return (o1, o2) -> Double.compare(counter.getCount(o1), counter.getCount(o2)); } /** * Returns a comparator backed by this counter: two objects are compared by * their associated values stored in the counter. This comparator returns keys * by ascending numeric value. 
Note that this ordering is not fixed, but * depends on the mutable values stored in the Counter. Doing this comparison * does not depend on the type of the key, since it uses the numeric value, * which is always Comparable. * * @param counter The Counter whose values are used for ordering the keys * @return A Comparator using this ordering */ public static <E extends Comparable<E>> Comparator<E> toComparatorWithKeys(final Counter<E> counter) { return (o1, o2) -> { int res = Double.compare(counter.getCount(o1), counter.getCount(o2)); if (res == 0) { return o1.compareTo(o2); } else { return res; } }; } /** * Returns a comparator backed by this counter: two objects are compared by * their associated values stored in the counter. This comparator returns keys * by descending numeric value. Note that this ordering is not fixed, but * depends on the mutable values stored in the Counter. Doing this comparison * does not depend on the type of the key, since it uses the numeric value, * which is always Comparable. * * @param counter The Counter whose values are used for ordering the keys * @return A Comparator using this ordering */ public static <E> Comparator<E> toComparatorDescending(final Counter<E> counter) { return (o1, o2) -> Double.compare(counter.getCount(o2), counter.getCount(o1)); } /** * Returns a comparator suitable for sorting this Counter's keys or entries by * their respective value or magnitude (by absolute value). If * <tt>ascending</tt> is true, smaller magnitudes will be returned first, * otherwise higher magnitudes will be returned first. * <p/> * Sample usage: * * <pre> * Counter c = new Counter(); * // add to the counter... 
* List biggestAbsKeys = new ArrayList(c.keySet()); * Collections.sort(biggestAbsKeys, Counters.comparator(c, false, true)); * List smallestEntries = new ArrayList(c.entrySet()); * Collections.sort(smallestEntries, Counters.comparator(c, true, false)); * </pre> */ public static <E> Comparator<E> toComparator(final Counter<E> counter, final boolean ascending, final boolean useMagnitude) { return (o1, o2) -> { if (ascending) { if (useMagnitude) { return Double.compare(Math.abs(counter.getCount(o1)), Math.abs(counter.getCount(o2))); } else { return Double.compare(counter.getCount(o1), counter.getCount(o2)); } } else { // Descending if (useMagnitude) { return Double.compare(Math.abs(counter.getCount(o2)), Math.abs(counter.getCount(o1))); } else { return Double.compare(counter.getCount(o2), counter.getCount(o1)); } } }; } /** * A List of the keys in c, sorted from highest count to lowest. * So note that the default is descending! * * @return A List of the keys in c, sorted from highest count to lowest. */ public static <E> List<E> toSortedList(Counter<E> c) { return toSortedList(c, false); } /** * A List of the keys in c, sorted from highest count to lowest. * * @return A List of the keys in c, sorted from highest count to lowest. */ public static <E> List<E> toSortedList(Counter<E> c, boolean ascending) { List<E> l = new ArrayList<>(c.keySet()); Comparator<E> comp = ascending ? toComparator(c) : toComparatorDescending(c); Collections.sort(l, comp); return l; } /** * A List of the keys in c, sorted from highest count to lowest. * * @return A List of the keys in c, sorted from highest count to lowest. 
*/ public static <E extends Comparable<E>> List<E> toSortedListKeyComparable(Counter<E> c) { List<E> l = new ArrayList<>(c.keySet()); Comparator<E> comp = toComparatorWithKeys(c); Collections.sort(l, comp); Collections.reverse(l); return l; } /** * Converts a counter to ranks; ranks start from 0 * * @return A counter where the count is the rank in the original counter */ public static <E> IntCounter<E> toRankCounter(Counter<E> c) { IntCounter<E> rankCounter = new IntCounter<>(); List<E> sortedList = toSortedList(c); for (int i = 0; i < sortedList.size(); i++) { rankCounter.setCount(sortedList.get(i), i); } return rankCounter; } /** * Converts a counter to tied ranks; ranks start from 1 * * @return A counter where the count is the rank in the original counter; when values are tied, the rank is the average of the ranks of the tied values */ public static <E> Counter<E> toTiedRankCounter(Counter<E> c) { Counter<E> rankCounter = new ClassicCounter<>(); List<Pair<E, Double>> sortedList = toSortedListWithCounts(c); int i = 0; Iterator<Pair<E, Double>> it = sortedList.iterator(); while(it.hasNext()) { Pair<E, Double> iEn = it.next(); double icount = iEn.second(); E iKey = iEn.first(); List<Integer> l = new ArrayList<>(); List<E> keys = new ArrayList<>(); l.add(i+1); keys.add(iKey); for(int j = i +1; j < sortedList.size(); j++){ Pair<E, Double> jEn = sortedList.get(j); if( icount == jEn.second()){ l.add(j+1); keys.add(jEn.first()); }else break; } if(l.size() > 1){ double sum = 0; for(Integer d: l) sum += d; double avgRank = sum/l.size(); for(int k = 0; k < l.size(); k++){ rankCounter.setCount(keys.get(k), avgRank); if(k != l.size()-1 && it.hasNext()) it.next(); i++; } }else{ rankCounter.setCount(iKey, i+1); i++; } } return rankCounter; } public static <E> List<Pair<E, Double>> toDescendingMagnitudeSortedListWithCounts(Counter<E> c) { List<E> keys = new ArrayList<>(c.keySet()); Collections.sort(keys, toComparator(c, false, true)); List<Pair<E, Double>> l = new 
ArrayList<>(keys.size()); for (E key : keys) { l.add(new Pair<>(key, c.getCount(key))); } return l; } /** * A List of the keys in c, sorted from highest count to lowest, paired with * counts * * @return A List of the keys in c, sorted from highest count to lowest. */ public static <E> List<Pair<E, Double>> toSortedListWithCounts(Counter<E> c) { List<Pair<E, Double>> l = new ArrayList<>(c.size()); for (E e : c.keySet()) { l.add(new Pair<>(e, c.getCount(e))); } // descending order Collections.sort(l, (a, b) -> Double.compare(b.second, a.second)); return l; } /** * A List of the keys in c, sorted by the given comparator, paired with * counts. * * @return A List of the keys in c, sorted from highest count to lowest. */ public static <E> List<Pair<E, Double>> toSortedListWithCounts(Counter<E> c, Comparator<Pair<E,Double>> comparator) { List<Pair<E, Double>> l = new ArrayList<>(c.size()); for (E e : c.keySet()) { l.add(new Pair<>(e, c.getCount(e))); } // descending order Collections.sort(l, comparator); return l; } /** * Returns a {@link edu.stanford.nlp.util.PriorityQueue} whose elements are * the keys of Counter c, and the score of each key in c becomes its priority. * * @param c Input Counter * @return A PriorityQueue where the count is a key's priority */ // TODO: rewrite to use entrySet() public static <E> edu.stanford.nlp.util.PriorityQueue<E> toPriorityQueue(Counter<E> c) { edu.stanford.nlp.util.PriorityQueue<E> queue = new BinaryHeapPriorityQueue<>(); for (E key : c.keySet()) { double count = c.getCount(key); queue.add(key, count); } return queue; } // // Other Utilities // /** * Returns a Counter that is the union of the two Counters passed in (counts * are added). * * @return A Counter that is the union of the two Counters passed in (counts * are added). 
*/
  @SuppressWarnings("unchecked")
  public static <E, C extends Counter<E>> C union(C c1, C c2) {
    // The result is created by c1's factory; neither argument is modified.
    C combined = (C) c1.getFactory().create();
    addInPlace(combined, c1);
    addInPlace(combined, c2);
    return combined;
  }

  /**
   * Returns a counter that is the intersection of c1 and c2. If both c1 and c2
   * contain a key, the min of the two counts is used. Only keys whose min
   * count is strictly positive are stored in the result.
   *
   * @return A counter that is the intersection of c1 and c2
   */
  public static <E> Counter<E> intersection(Counter<E> c1, Counter<E> c2) {
    Counter<E> common = c1.getFactory().create();
    for (E key : Sets.union(c1.keySet(), c2.keySet())) {
      double count1 = c1.getCount(key);
      double count2 = c2.getCount(key);
      double smaller;
      if (count1 < count2) {
        smaller = count1;
      } else {
        smaller = count2;
      }
      if (smaller > 0) {
        common.setCount(key, smaller);
      }
    }
    return common;
  }

  /**
   * Returns the Jaccard Coefficient of the two counters. Calculated as |c1
   * intersect c2| / ( |c1| + |c2| - |c1 intersect c2| ), i.e. the sum over
   * all keys of the smaller count divided by the sum of the larger count.
   *
   * @return The Jaccard Coefficient of the two counters
   */
  public static <E> double jaccardCoefficient(Counter<E> c1, Counter<E> c2) {
    double sumOfMins = 0.0;
    double sumOfMaxes = 0.0;
    for (E key : Sets.union(c1.keySet(), c2.keySet())) {
      double count1 = c1.getCount(key);
      double count2 = c2.getCount(key);
      if (count1 < count2) {
        sumOfMins += count1;
      } else {
        sumOfMins += count2;
      }
      if (count1 > count2) {
        sumOfMaxes += count1;
      } else {
        sumOfMaxes += count2;
      }
    }
    return sumOfMins / sumOfMaxes;
  }

  /**
   * Returns the key-wise product of c1 and c2, restricted to the keys that
   * appear in both counters.
   *
   * @return The product of c1 and c2.
   */
  public static <E> Counter<E> product(Counter<E> c1, Counter<E> c2) {
    Counter<E> products = c1.getFactory().create();
    for (E shared : Sets.intersection(c1.keySet(), c2.keySet())) {
      double value = c1.getCount(shared) * c2.getCount(shared);
      products.setCount(shared, value);
    }
    return products;
  }

  /**
   * Returns the product of c1 and c2.
   *
   * @return The product of c1 and c2.
*/ public static <E> double dotProduct(Counter<E> c1, Counter<E> c2) { double dotProd = 0.0; if (c1.size() > c2.size()) { Counter<E> tmpCnt = c1; c1 = c2; c2 = tmpCnt; } for (E key : c1.keySet()) { double count1 = c1.getCount(key); if (Double.isNaN(count1) || Double.isInfinite(count1)) { throw new RuntimeException("Counters.dotProduct infinite or NaN value for key: " + key + '\t' + c1.getCount(key) + '\t' + c2.getCount(key)); } if (count1 != 0.0) { double count2 = c2.getCount(key); if (Double.isNaN(count2) || Double.isInfinite(count2)) { throw new RuntimeException("Counters.dotProduct infinite or NaN value for key: " + key + '\t' + c1.getCount(key) + '\t' + c2.getCount(key)); } if (count2 != 0.0) { // this is the inner product dotProd += (count1 * count2); } } } return dotProd; } /** * Returns the product of Counter c and double[] a, using Index idx to map * entries in C onto a. * * @return The product of c and a. */ public static <E> double dotProduct(Counter<E> c, double[] a, Index<E> idx) { double dotProd = 0.0; for (Map.Entry<E, Double> entry : c.entrySet()) { int keyIdx = idx.indexOf(entry.getKey()); if (keyIdx >= 0) { dotProd += entry.getValue() * a[keyIdx]; } } return dotProd; } public static <E> double sumEntries(Counter<E> c1, Collection<E> entries) { double dotProd = 0.0; for (E entry : entries) { dotProd += c1.getCount(entry); } return dotProd; } public static <E> Counter<E> add(Counter<E> c1, Collection<E> c2) { Counter<E> result = c1.getFactory().create(); addInPlace(result, c1); for (E key : c2) { result.incrementCount(key, 1); } return result; } public static <E> Counter<E> add(Counter<E> c1, Counter<E> c2) { Counter<E> result = c1.getFactory().create(); for (E key : Sets.union(c1.keySet(), c2.keySet())) { result.setCount(key, c1.getCount(key) + c2.getCount(key)); } retainNonZeros(result); return result; } /** * increments every key in the counter by value */ public static <E> Counter<E> add(Counter<E> c1, double value) { Counter<E> result = 
c1.getFactory().create(); for (E key : c1.keySet()) { result.setCount(key, c1.getCount(key) + value); } return result; } /** * This method does not check entries for NAN or INFINITY values in the * doubles returned. It also only iterates over the counter with the smallest * number of keys to help speed up computation. Pair this method with * normalizing your counters before hand and you have a reasonably quick * implementation of cosine. * * @param <E> * @param c1 * @param c2 * @return The dot product of the two counter (as vectors) */ public static <E> double optimizedDotProduct(Counter<E> c1, Counter<E> c2) { int size1 = c1.size(); int size2 = c2.size(); if (size1 < size2) { return getDotProd(c1, c2); } else { return getDotProd(c2, c1); } } private static <E> double getDotProd(Counter<E> c1, Counter<E> c2) { double dotProd = 0.0; for (E key : c1.keySet()) { double count1 = c1.getCount(key); if (count1 != 0.0) { double count2 = c2.getCount(key); if (count2 != 0.0) dotProd += (count1 * count2); } } return dotProd; } /** * Returns |c1 - c2|. * * @return The difference between sets c1 and c2. */ public static <E> Counter<E> absoluteDifference(Counter<E> c1, Counter<E> c2) { Counter<E> result = c1.getFactory().create(); for (E key : Sets.union(c1.keySet(), c2.keySet())) { double newCount = Math.abs(c1.getCount(key) - c2.getCount(key)); if (newCount > 0) { result.setCount(key, newCount); } } return result; } /** * Returns c1 divided by c2. Note that this can create NaN if c1 has non-zero * counts for keys that c2 has zero counts. * * @return c1 divided by c2. */ public static <E> Counter<E> division(Counter<E> c1, Counter<E> c2) { Counter<E> result = c1.getFactory().create(); for (E key : Sets.union(c1.keySet(), c2.keySet())) { result.setCount(key, c1.getCount(key) / c2.getCount(key)); } return result; } /** * Returns c1 divided by c2. Safe - will not calculate scores for keys that are zero or that do not exist in c2 * * @return c1 divided by c2. 
*/ public static <E> Counter<E> divisionNonNaN(Counter<E> c1, Counter<E> c2) { Counter<E> result = c1.getFactory().create(); for (E key : Sets.union(c1.keySet(), c2.keySet())) { if(c2.getCount(key) != 0) result.setCount(key, c1.getCount(key) / c2.getCount(key)); } return result; } /** * Calculates the entropy of the given counter (in bits). This method * internally uses normalized counts (so they sum to one), but the value * returned is meaningless if some of the counts are negative. * * @return The entropy of the given counter (in bits) */ public static <E> double entropy(Counter<E> c) { double entropy = 0.0; double total = c.totalCount(); for (E key : c.keySet()) { double count = c.getCount(key); if (count == 0) { continue; // 0.0 doesn't add entropy but may cause -Inf } count /= total; // use normalized count entropy -= count * (Math.log(count) / LOG_E_2); } return entropy; } /** * Note that this implementation doesn't normalize the "from" Counter. It * does, however, normalize the "to" Counter. Result is meaningless if any of * the counts are negative. * * @return The cross entropy of H(from, to) */ public static <E> double crossEntropy(Counter<E> from, Counter<E> to) { double tot2 = to.totalCount(); double result = 0.0; for (E key : from.keySet()) { double count1 = from.getCount(key); if (count1 == 0.0) { continue; } double count2 = to.getCount(key); double logFract = Math.log(count2 / tot2); if (logFract == Double.NEGATIVE_INFINITY) { return Double.NEGATIVE_INFINITY; // can't recover } result += count1 * (logFract / LOG_E_2); // express it in log base 2 } return result; } /** * Calculates the KL divergence between the two counters. That is, it * calculates KL(from || to). This method internally uses normalized counts * (so they sum to one), but the value returned is meaningless if any of the * counts are negative. In other words, how well can c1 be represented by c2. * if there is some value in c1 that gets zero prob in c2, then return * positive infinity. 
* * @return The KL divergence between the distributions */ public static <E> double klDivergence(Counter<E> from, Counter<E> to) { double result = 0.0; double tot = (from.totalCount()); double tot2 = (to.totalCount()); // System.out.println("tot is " + tot + " tot2 is " + tot2); for (E key : from.keySet()) { double num = (from.getCount(key)); if (num == 0) { continue; } num /= tot; double num2 = (to.getCount(key)); num2 /= tot2; // System.out.println("num is " + num + " num2 is " + num2); double logFract = Math.log(num / num2); if (logFract == Double.NEGATIVE_INFINITY) { return Double.NEGATIVE_INFINITY; // can't recover } result += num * (logFract / LOG_E_2); // express it in log base 2 } return result; } /** * Calculates the Jensen-Shannon divergence between the two counters. That is, * it calculates 1/2 [KL(c1 || avg(c1,c2)) + KL(c2 || avg(c1,c2))] . * This code assumes that the Counters have only non-negative values in them. * * @return The Jensen-Shannon divergence between the distributions */ public static <E> double jensenShannonDivergence(Counter<E> c1, Counter<E> c2) { // need to normalize the counters first before averaging them! Else buggy if not a probability distribution Counter<E> d1 = asNormalizedCounter(c1); Counter<E> d2 = asNormalizedCounter(c2); Counter<E> average = average(d1, d2); double kl1 = klDivergence(d1, average); double kl2 = klDivergence(d2, average); return (kl1 + kl2) / 2.0; } /** * Calculates the skew divergence between the two counters. That is, it * calculates KL(c1 || (c2*skew + c1*(1-skew))) . In other words, how well can * c1 be represented by a "smoothed" c2. 
* * @return The skew divergence between the distributions */ public static <E> double skewDivergence(Counter<E> c1, Counter<E> c2, double skew) { Counter<E> d1 = asNormalizedCounter(c1); Counter<E> d2 = asNormalizedCounter(c2); Counter<E> average = linearCombination(d2, skew, d1, (1.0 - skew)); return klDivergence(d1, average); } /** * Return the l2 norm (Euclidean vector length) of a Counter. * <i>Implementation note:</i> The method name favors legibility of the L over * the convention of using lowercase names for methods. * * @param c The Counter * @return Its length */ public static <E, C extends Counter<E>> double L2Norm(C c) { return Math.sqrt(Counters.sumSquares(c)); } /** * Return the sum of squares (squared L2 norm). * * @param c The Counter * @return the L2 norm of the values in c */ public static <E, C extends Counter<E>> double sumSquares(C c) { double lenSq = 0.0; for (E key : c.keySet()) { double count = c.getCount(key); lenSq += (count * count); } return lenSq; } /** * Return the L1 norm of a counter. <i>Implementation note:</i> The method * name favors legibility of the L over the convention of using lowercase * names for methods. * * @param c The Counter * @return Its length */ public static <E, C extends Counter<E>> double L1Norm(C c) { double sumAbs = 0.0; for (E key : c.keySet()) { double count = c.getCount(key); if (count != 0.0) { sumAbs += Math.abs(count); } } return sumAbs; } /** * L2 normalize a counter. * * @param c The {@link Counter} to be L2 normalized. This counter is not * modified. * @return A new l2-normalized Counter based on c. */ public static <E, C extends Counter<E>> C L2Normalize(C c) { return scale(c, 1.0 / L2Norm(c)); } /** * L2 normalize a counter in place. * * @param c The {@link Counter} to be L2 normalized. 
This counter is modified * @return the passed in counter l2-normalized */ public static <E> Counter<E> L2NormalizeInPlace(Counter<E> c) { return multiplyInPlace(c, 1.0 / L2Norm(c)); } /** * For counters with large # of entries, this scales down each entry in the * sum, to prevent an extremely large sum from building up and overwhelming * the max double. This may also help reduce error by preventing loss of SD's * with extremely large values. * * @param <E> * @param <C> */ public static <E, C extends Counter<E>> double saferL2Norm(C c) { double maxVal = 0.0; for (E key : c.keySet()) { double value = Math.abs(c.getCount(key)); if (value > maxVal) maxVal = value; } double sqrSum = 0.0; for (E key : c.keySet()) { double count = c.getCount(key); sqrSum += Math.pow(count / maxVal, 2); } return maxVal * Math.sqrt(sqrSum); } /** * L2 normalize a counter, using the "safer" L2 normalizer. * * @param c The {@link Counter} to be L2 normalized. This counter is not * modified. * @return A new L2-normalized Counter based on c. */ public static <E, C extends Counter<E>> C saferL2Normalize(C c) { return scale(c, 1.0 / saferL2Norm(c)); } public static <E> double cosine(Counter<E> c1, Counter<E> c2) { double dotProd = 0.0; double lsq1 = 0.0; double lsq2 = 0.0; for (E key : c1.keySet()) { double count1 = c1.getCount(key); if (count1 != 0.0) { lsq1 += (count1 * count1); double count2 = c2.getCount(key); if (count2 != 0.0) { // this is the inner product dotProd += (count1 * count2); } } } for (E key : c2.keySet()) { double count2 = c2.getCount(key); if (count2 != 0.0) { lsq2 += (count2 * count2); } } if (lsq1 != 0.0 && lsq2 != 0.0) { double denom = (Math.sqrt(lsq1) * Math.sqrt(lsq2)); return dotProd / denom; } return 0.0; } /** * Returns a new Counter with counts averaged from the two given Counters. 
The * average Counter will contain the union of keys in both source Counters, and * each count will be the average of the two source counts for that key, where * as usual a missing count in one Counter is treated as count 0. * * @return A new counter with counts that are the mean of the resp. counts in * the given counters. */ public static <E> Counter<E> average(Counter<E> c1, Counter<E> c2) { Counter<E> average = c1.getFactory().create(); Set<E> allKeys = Generics.newHashSet(c1.keySet()); allKeys.addAll(c2.keySet()); for (E key : allKeys) { average.setCount(key, (c1.getCount(key) + c2.getCount(key)) * 0.5); } return average; } /** * Returns a Counter which is a weighted average of c1 and c2. Counts from c1 * are weighted with weight w1 and counts from c2 are weighted with w2. */ public static <E> Counter<E> linearCombination(Counter<E> c1, double w1, Counter<E> c2, double w2) { Counter<E> result = c1.getFactory().create(); for (E o : c1.keySet()) { result.incrementCount(o, c1.getCount(o) * w1); } for (E o : c2.keySet()) { result.incrementCount(o, c2.getCount(o) * w2); } return result; } public static <T1, T2> double pointwiseMutualInformation(Counter<T1> var1Distribution, Counter<T2> var2Distribution, Counter<Pair<T1, T2>> jointDistribution, Pair<T1, T2> values) { double var1Prob = var1Distribution.getCount(values.first); double var2Prob = var2Distribution.getCount(values.second); double jointProb = jointDistribution.getCount(values); double pmi = Math.log(jointProb) - Math.log(var1Prob) - Math.log(var2Prob); return pmi / LOG_E_2; } /** * Calculate h-Index (Hirsch, 2005) of an author. * * A scientist has index h if h of their Np papers have at least h citations * each, and the other (Np − h) papers have at most h citations each. * * @param citationCounts * Citation counts for each of the articles written by the author. * The keys can be anything, but the values should be integers. * @return The h-Index of the author. 
*/ public static <E> int hIndex(Counter<E> citationCounts) { Counter<Integer> countCounts = new ClassicCounter<>(); for (double value : citationCounts.values()) { for (int i = 0; i <= value; ++i) { countCounts.incrementCount(i); } } List<Integer> citationCountValues = CollectionUtils.sorted(countCounts.keySet()); Collections.reverse(citationCountValues); for (int citationCount : citationCountValues) { double occurrences = countCounts.getCount(citationCount); if (occurrences >= citationCount) { return citationCount; } } return 0; } @SuppressWarnings("unchecked") public static <E, C extends Counter<E>> C perturbCounts(C c, Random random, double p) { C result = (C) c.getFactory().create(); for (E key : c.keySet()) { double count = c.getCount(key); double noise = -Math.log(1.0 - random.nextDouble()); // inverse of CDF for // exponential // distribution // log.info("noise=" + noise); double perturbedCount = count + noise * p; result.setCount(key, perturbedCount); } return result; } /** * Great for debugging. * */ public static <E> void printCounterComparison(Counter<E> a, Counter<E> b) { printCounterComparison(a, b, System.err); } /** * Great for debugging. * */ public static <E> void printCounterComparison(Counter<E> a, Counter<E> b, PrintStream out) { printCounterComparison(a, b, new PrintWriter(out, true)); } /** * Prints one or more lines (with a newline at the end) describing the * difference between the two Counters. Great for debugging. * */ public static <E> void printCounterComparison(Counter<E> a, Counter<E> b, PrintWriter out) { if (a.equals(b)) { out.println("Counters are equal."); return; } for (E key : a.keySet()) { double aCount = a.getCount(key); double bCount = b.getCount(key); if (Math.abs(aCount - bCount) > 1e-5) { out.println("Counters differ on key " + key + '\t' + a.getCount(key) + " vs. 
" + b.getCount(key)); } } // left overs Set<E> rest = Generics.newHashSet(b.keySet()); rest.removeAll(a.keySet()); for (E key : rest) { double aCount = a.getCount(key); double bCount = b.getCount(key); if (Math.abs(aCount - bCount) > 1e-5) { out.println("Counters differ on key " + key + '\t' + a.getCount(key) + " vs. " + b.getCount(key)); } } } public static <E> Counter<Double> getCountCounts(Counter<E> c) { Counter<Double> result = new ClassicCounter<>(); for (double v : c.values()) { result.incrementCount(v); } return result; } /** * Returns a new Counter which is scaled by the given scale factor. * * @param c The counter to scale. It is not changed * @param s The constant to scale the counter by * @return A new Counter which is the argument scaled by the given scale * factor. */ @SuppressWarnings("unchecked") public static <E, C extends Counter<E>> C scale(C c, double s) { C scaled = (C) c.getFactory().create(); for (E key : c.keySet()) { scaled.setCount(key, c.getCount(key) * s); } return scaled; } /** * Returns a new Counter which is the input counter with log tf scaling * * @param c The counter to scale. It is not changed * @param base The base of the logarithm used for tf scaling by 1 + log tf * @return A new Counter which is the argument scaled by the given scale * factor. */ @SuppressWarnings("unchecked") public static <E, C extends Counter<E>> C tfLogScale(C c, double base) { C scaled = (C) c.getFactory().create(); for (E key : c.keySet()) { double cnt = c.getCount(key); double scaledCnt = 0.0; if (cnt > 0) { scaledCnt = 1.0 + SloppyMath.log(cnt, base); } scaled.setCount(key, scaledCnt); } return scaled; } public static <E extends Comparable<E>> void printCounterSortedByKeys(Counter<E> c) { List<E> keyList = new ArrayList<>(c.keySet()); Collections.sort(keyList); for (E o : keyList) { System.out.println(o + ":" + c.getCount(o)); } } /** * Loads a Counter from a text file. File must have the format of one * key/count pair per line, separated by whitespace. 
* * @param filename The path to the file to load the Counter from * @param c The Class to instantiate each member of the set. Must have a * String constructor. * @return The counter loaded from the file. */ public static <E> ClassicCounter<E> loadCounter(String filename, Class<E> c) throws RuntimeException { ClassicCounter<E> counter = new ClassicCounter<>(); loadIntoCounter(filename, c, counter); return counter; } /** * Loads a Counter from a text file. File must have the format of one * key/count pair per line, separated by whitespace. * * @param filename The path to the file to load the Counter from * @param c The Class to instantiate each member of the set. Must have a * String constructor. * @return The counter loaded from the file. */ public static <E> IntCounter<E> loadIntCounter(String filename, Class<E> c) throws Exception { IntCounter<E> counter = new IntCounter<>(); loadIntoCounter(filename, c, counter); return counter; } /** * Loads a file into an GenericCounter. */ private static <E> void loadIntoCounter(String filename, Class<E> c, Counter<E> counter) throws RuntimeException { try { Constructor<E> m = c.getConstructor(String.class); BufferedReader in = IOUtils.getBufferedFileReader(filename); for (String line; (line = in.readLine()) != null;) { String[] tokens = line.trim().split("\\s+"); if (tokens.length != 2) throw new RuntimeException(); double value = Double.parseDouble(tokens[1]); counter.setCount(m.newInstance(tokens[0]), value); } in.close(); } catch (Exception e) { throw new RuntimeException(e); } } /** * Saves a Counter as one key/count pair per line separated by white space to * the given OutputStream. Does not close the stream. */ public static <E> void saveCounter(Counter<E> c, OutputStream stream) { PrintStream out = new PrintStream(stream); for (E key : c.keySet()) { out.println(key + " " + c.getCount(key)); } } /** * Saves a Counter to a text file. Counter written as one key/count pair per * line, separated by whitespace. 
*/ public static <E> void saveCounter(Counter<E> c, String filename) throws IOException { FileOutputStream fos = new FileOutputStream(filename); saveCounter(c, fos); fos.close(); } public static <T1, T2> TwoDimensionalCounter<T1, T2> load2DCounter(String filename, Class<T1> t1, Class<T2> t2) throws RuntimeException { try { TwoDimensionalCounter<T1, T2> tdc = new TwoDimensionalCounter<>(); loadInto2DCounter(filename, t1, t2, tdc); return tdc; } catch (Exception e) { throw new RuntimeException(e); } } public static <T1, T2> void loadInto2DCounter(String filename, Class<T1> t1, Class<T2> t2, TwoDimensionalCounter<T1, T2> tdc) throws RuntimeException { try { Constructor<T1> m1 = t1.getConstructor(String.class); Constructor<T2> m2 = t2.getConstructor(String.class); BufferedReader in = IOUtils.getBufferedFileReader(filename);// new // BufferedReader(new // FileReader(filename)); for (String line; (line = in.readLine()) != null;) { String[] tuple = line.trim().split("\t"); String outer = tuple[0]; String inner = tuple[1]; String valStr = tuple[2]; tdc.setCount(m1.newInstance(outer.trim()), m2.newInstance(inner.trim()), Double.parseDouble(valStr.trim())); } in.close(); } catch (Exception e) { throw new RuntimeException(e); } } public static <T1, T2> void loadIncInto2DCounter(String filename, Class<T1> t1, Class<T2> t2, TwoDimensionalCounterInterface<T1, T2> tdc) throws RuntimeException { try { Constructor<T1> m1 = t1.getConstructor(String.class); Constructor<T2> m2 = t2.getConstructor(String.class); BufferedReader in = IOUtils.getBufferedFileReader(filename);// new // BufferedReader(new // FileReader(filename)); for (String line; (line = in.readLine()) != null;) { String[] tuple = line.trim().split("\t"); String outer = tuple[0]; String inner = tuple[1]; String valStr = tuple[2]; tdc.incrementCount(m1.newInstance(outer.trim()), m2.newInstance(inner.trim()), Double.parseDouble(valStr.trim())); } in.close(); } catch (Exception e) { throw new RuntimeException(e); } } public 
static <T1, T2> void save2DCounter(TwoDimensionalCounter<T1, T2> tdc, String filename) throws IOException { PrintWriter out = new PrintWriter(new FileWriter(filename)); for (T1 outer : tdc.firstKeySet()) { for (T2 inner : tdc.secondKeySet()) { out.println(outer + "\t" + inner + '\t' + tdc.getCount(outer, inner)); } } out.close(); } public static <T1, T2> void save2DCounterSorted(TwoDimensionalCounterInterface<T1, T2> tdc, String filename) throws IOException { PrintWriter out = new PrintWriter(new FileWriter(filename)); for (T1 outer : tdc.firstKeySet()) { Counter<T2> c = tdc.getCounter(outer); List<T2> keys = Counters.toSortedList(c); for (T2 inner : keys) { out.println(outer + "\t" + inner + '\t' + c.getCount(inner)); } } out.close(); } /** * Serialize a counter into an efficient string TSV * @param c The counter to serialize * @param filename The file to serialize to * @param minMagnitude Ignore values under this magnitude * @throws IOException * * @see Counters#deserializeStringCounter(String) */ public static void serializeStringCounter(Counter<String> c, String filename, double minMagnitude) throws IOException { PrintWriter writer = IOUtils.getPrintWriter(filename); for (Entry<String, Double> entry : c.entrySet()) { if (Math.abs(entry.getValue()) < minMagnitude) { continue; } Triple<Boolean, Long, Integer> parts = SloppyMath.segmentDouble(entry.getValue()); writer.println( entry.getKey().replace('\t', 'ߝ') + "\t" + (parts.first ? 
'-' : '+') + "\t" + parts.second + "\t" + parts.third ); } writer.close(); } /** @see Counters#serializeStringCounter(Counter, String, double) */ public static void serializeStringCounter(Counter<String> c, String filename) throws IOException { serializeStringCounter(c, filename, 0.0); } /** * Read a Counter from a serialized file * @param filename The file to read from * * @see Counters#serializeStringCounter(Counter, String, double) */ public static ClassicCounter<String> deserializeStringCounter(String filename) throws IOException { String[] fields = new String[4]; BufferedReader reader = IOUtils.readerFromString(filename); String line; ClassicCounter<String> counts = new ClassicCounter<>(1000000); while ( (line = reader.readLine()) != null) { StringUtils.splitOnChar(fields, line, '\t'); long mantissa = SloppyMath.parseInt(fields[2]); int exponent = (int) SloppyMath.parseInt(fields[3]); double value = SloppyMath.parseDouble(fields[1].equals("-"), mantissa, exponent); counts.setCount(fields[0], value); } return counts; } public static <T> void serializeCounter(Counter<T> c, String filename) throws IOException { // serialize to file ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(filename))); out.writeObject(c); out.close(); } public static <T> ClassicCounter<T> deserializeCounter(String filename) throws Exception { // reconstitute ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new FileInputStream(filename))); ClassicCounter<T> c = ErasureUtils.uncheckedCast(in.readObject()); in.close(); return c; } /** * Returns a string representation of a Counter, displaying the keys and their * counts in decreasing order of count. At most k keys are displayed. 
* * Note that this method subsumes many of the other toString methods, e.g.: * * toString(c, k) and toBiggestValuesFirstString(c, k) => toSortedString(c, k, * "%s=%f", ", ", "[%s]") * * toVerticalString(c, k) => toSortedString(c, k, "%2$g\t%1$s", "\n", "%s\n") * * @param counter A Counter. * @param k The number of keys to include. Use Integer.MAX_VALUE to include * all keys. * @param itemFormat * The format string for key/count pairs, where the key is first and * the value is second. To display the value first, use argument * indices, e.g. "%2$f %1$s". * @param joiner The string used between pairs of key/value strings. * @param wrapperFormat * The format string for wrapping text around the joined items, where * the joined item string value is "%s". * @return The top k values from the Counter, formatted as specified. */ public static <T> String toSortedString(Counter<T> counter, int k, String itemFormat, String joiner, String wrapperFormat) { PriorityQueue<T> queue = toPriorityQueue(counter); List<String> strings = new ArrayList<>(); for (int rank = 0; rank < k && !queue.isEmpty(); ++rank) { T key = queue.removeFirst(); double value = counter.getCount(key); strings.add(String.format(itemFormat, key, value)); } return String.format(wrapperFormat, StringUtils.join(strings, joiner)); } /** * Returns a string representation of a Counter, displaying the keys and their * counts in decreasing order of count. At most k keys are displayed. * * @param counter A Counter. * @param k * The number of keys to include. Use Integer.MAX_VALUE to include * all keys. * @param itemFormat * The format string for key/count pairs, where the key is first and * the value is second. To display the value first, use argument * indices, e.g. "%2$f %1$s". * @param joiner * The string used between pairs of key/value strings. * @return The top k values from the Counter, formatted as specified. 
*/
  public static <T> String toSortedString(Counter<T> counter, int k, String itemFormat, String joiner) {
    // "%s" as the wrapper leaves the joined items unadorned
    String plainWrapper = "%s";
    return toSortedString(counter, k, itemFormat, joiner, plainWrapper);
  }

  /**
   * Returns a string representation of a Counter, where (key, value) pairs are
   * sorted by key, and formatted as specified.
   *
   * @param counter The Counter.
   * @param itemFormat
   *          The format string for key/count pairs, where the key is first and
   *          the value is second. To display the value first, use argument
   *          indices, e.g. "%2$f %1$s".
   * @param joiner
   *          The string used between pairs of key/value strings.
   * @param wrapperFormat
   *          The format string for wrapping text around the joined items, where
   *          the joined item string value is "%s".
   * @return The Counter, formatted as specified.
   */
  public static <T extends Comparable<T>> String toSortedByKeysString(Counter<T> counter, String itemFormat, String joiner, String wrapperFormat) {
    List<String> items = new ArrayList<>();
    List<T> orderedKeys = CollectionUtils.sorted(counter.keySet());
    for (T key : orderedKeys) {
      double count = counter.getCount(key);
      items.add(String.format(itemFormat, key, count));
    }
    String joined = StringUtils.join(items, joiner);
    return String.format(wrapperFormat, joined);
  }

  /**
   * Returns a string representation which includes no more than the
   * maxKeysToPrint elements with largest counts. If maxKeysToPrint is
   * non-positive, all elements are printed.
* * @param counter The Counter * @param maxKeysToPrint Max keys to print * @return A partial string representation */ public static <E> String toString(Counter<E> counter, int maxKeysToPrint) { return Counters.toPriorityQueue(counter).toString(maxKeysToPrint); } public static <E> String toString(Counter<E> counter, NumberFormat nf) { StringBuilder sb = new StringBuilder(); sb.append('{'); List<E> list = ErasureUtils.sortedIfPossible(counter.keySet()); // */ for (Iterator<E> iter = list.iterator(); iter.hasNext();) { E key = iter.next(); sb.append(key); sb.append('='); sb.append(nf.format(counter.getCount(key))); if (iter.hasNext()) { sb.append(", "); } } sb.append('}'); return sb.toString(); } /** * Pretty print a Counter. This one has more flexibility in formatting, and * doesn't sort the keys. */ public static <E> String toString(Counter<E> counter, NumberFormat nf, String preAppend, String postAppend, String keyValSeparator, String itemSeparator) { StringBuilder sb = new StringBuilder(); sb.append(preAppend); // List<E> list = new ArrayList<E>(map.keySet()); // try { // Collections.sort(list); // see if it can be sorted // } catch (Exception e) { // } for (Iterator<E> iter = counter.keySet().iterator(); iter.hasNext();) { E key = iter.next(); double d = counter.getCount(key); sb.append(key); sb.append(keyValSeparator); sb.append(nf.format(d)); if (iter.hasNext()) { sb.append(itemSeparator); } } sb.append(postAppend); return sb.toString(); } public static <E> String toBiggestValuesFirstString(Counter<E> c) { return toPriorityQueue(c).toString(); } // TODO this method seems badly written. It should exploit topK printing of PriorityQueue public static <E> String toBiggestValuesFirstString(Counter<E> c, int k) { PriorityQueue<E> pq = toPriorityQueue(c); PriorityQueue<E> largestK = new BinaryHeapPriorityQueue<>(); // TODO: Is there any reason the original (commented out) line is better // than the one replacing it? 
// while (largestK.size() < k && ((Iterator<E>)pq).hasNext()) { while (largestK.size() < k && !pq.isEmpty()) { double firstScore = pq.getPriority(pq.getFirst()); E first = pq.removeFirst(); largestK.changePriority(first, firstScore); } return largestK.toString(); } public static <T> String toBiggestValuesFirstString(Counter<Integer> c, int k, Index<T> index) { PriorityQueue<Integer> pq = toPriorityQueue(c); PriorityQueue<T> largestK = new BinaryHeapPriorityQueue<>(); // while (largestK.size() < k && ((Iterator)pq).hasNext()) { //same as above while (largestK.size() < k && !pq.isEmpty()) { double firstScore = pq.getPriority(pq.getFirst()); int first = pq.removeFirst(); largestK.changePriority(index.get(first), firstScore); } return largestK.toString(); } public static <E> String toVerticalString(Counter<E> c) { return toVerticalString(c, Integer.MAX_VALUE); } public static <E> String toVerticalString(Counter<E> c, int k) { return toVerticalString(c, k, "%g\t%s", false); } public static <E> String toVerticalString(Counter<E> c, String fmt) { return toVerticalString(c, Integer.MAX_VALUE, fmt, false); } public static <E> String toVerticalString(Counter<E> c, int k, String fmt) { return toVerticalString(c, k, fmt, false); } /** * Returns a {@code String} representation of the {@code k} keys * with the largest counts in the given {@link Counter}, using the given * format string. * * @param c A Counter * @param k How many keys to print * @param fmt A format string, such as "%.0f\t%s" (do not include final "%n"). * If swap is false, you will get val, key as arguments, if true, key, val. 
* @param swap Whether the count should appear after the key */ public static <E> String toVerticalString(Counter<E> c, int k, String fmt, boolean swap) { PriorityQueue<E> q = Counters.toPriorityQueue(c); List<E> sortedKeys = q.toSortedList(); StringBuilder sb = new StringBuilder(); int i = 0; for (Iterator<E> keyI = sortedKeys.iterator(); keyI.hasNext() && i < k; i++) { E key = keyI.next(); double val = q.getPriority(key); if (swap) { sb.append(String.format(fmt, key, val)); } else { sb.append(String.format(fmt, val, key)); } if (keyI.hasNext()) { sb.append('\n'); } } return sb.toString(); } /** * * @return Returns the maximum element of c that is within the restriction * Collection */ public static <E> E restrictedArgMax(Counter<E> c, Collection<E> restriction) { E maxKey = null; double max = Double.NEGATIVE_INFINITY; for (E key : restriction) { double count = c.getCount(key); if (count > max) { max = count; maxKey = key; } } return maxKey; } public static <T> Counter<T> toCounter(double[] counts, Index<T> index) { if (index.size() < counts.length) throw new IllegalArgumentException("Index not large enough to name all the array elements!"); Counter<T> c = new ClassicCounter<>(); for (int i = 0; i < counts.length; i++) { if (counts[i] != 0.0) c.setCount(index.get(i), counts[i]); } return c; } /** * Turns the given map and index into a counter instance. For each entry in * counts, its key is converted to a counter key via lookup in the given * index. */ public static <E> Counter<E> toCounter(Map<Integer, ? extends Number> counts, Index<E> index) { Counter<E> counter = new ClassicCounter<>(); for (Map.Entry<Integer, ? extends Number> entry : counts.entrySet()) { counter.setCount(index.get(entry.getKey()), entry.getValue().doubleValue()); } return counter; } /** * Convert a counter to an array using a specified key index. Infer the dimension of * the returned vector from the index. 
*/
  public static <E> double[] asArray(Counter<E> counter, Index<E> index) {
    final int dimension = index.size();
    return Counters.asArray(counter, index, dimension);
  }

  /**
   * Convert a counter to an array using a specified key index. This method does *not* expand
   * the index, so all keys in the set keys(counter) - keys(index) are not added to the
   * output array. Also note that if counter is being used as a sparse array, the result
   * will be a dense array with zero entries.
   *
   * @param counter The counter to convert
   * @param index Maps each key to its position in the array
   * @param dimension The length of the returned array
   * @return the values corresponding to the index
   * @throws IllegalArgumentException if the index is empty
   */
  public static <E> double[] asArray(Counter<E> counter, Index<E> index, int dimension) {
    if (index.size() == 0) {
      throw new IllegalArgumentException("Empty index");
    }
    Set<E> counterKeys = counter.keySet();
    double[] dense = new double[dimension];
    for (E key : counterKeys) {
      int position = index.indexOf(key);
      if (position < 0) {
        continue; // key is not in the index, so it has no slot
      }
      dense[position] = counter.getCount(key);
    }
    return dense;
  }

  /**
   * Convert a counter to an array. The values appear in the iteration order
   * of the counter's key set.
   */
  public static <E> double[] asArray(Counter<E> counter) {
    Set<E> counterKeys = counter.keySet();
    double[] values = new double[counter.size()];
    int next = 0;
    for (E key : counterKeys) {
      values[next++] = counter.getCount(key);
    }
    return values;
  }

  /**
   * Creates a new TwoDimensionalCounter where all the counts are scaled by d.
   * Internally, uses Counters.scale();
   *
   * @return The TwoDimensionalCounter
   */
  public static <T1, T2> TwoDimensionalCounter<T1, T2> scale(TwoDimensionalCounter<T1, T2> c, double d) {
    TwoDimensionalCounter<T1, T2> scaled =
        new TwoDimensionalCounter<>(c.getOuterMapFactory(), c.getInnerMapFactory());
    for (T1 outerKey : c.firstKeySet()) {
      ClassicCounter<T2> inner = c.getCounter(outerKey);
      scaled.setCounter(outerKey, scale(inner, d));
    }
    return scaled;
  }

  /** Shared source of randomness used by sampling when the caller supplies none. */
  static final Random RAND = new Random();

  /**
   * Does not assume that c is normalized.
* * @return A sample from c */ public static <T> T sample(Counter<T> c, Random rand) { // OMITTED: Seems like there should be a way to directly check if T is comparable // Set<T> keySet = c.keySet(); // if (!keySet.isEmpty() && keySet.iterator().next() instanceof Comparable) { // List l = new ArrayList<T>(keySet); // Collections.sort(l); // objects = l; // } else { // throw new RuntimeException("Results won't be stable since Counters keys are comparable."); // } if (rand == null) rand = RAND; double r = rand.nextDouble() * c.totalCount(); double total = 0.0; for (T t : c.keySet()) { // arbitrary ordering, but presumably stable total += c.getCount(t); if (total >= r) return t; } // only chance of reaching here is if c isn't properly normalized, or if // double math makes total<1.0 return c.keySet().iterator().next(); } /** * Does not assumes c is normalized. * * @return A sample from c */ public static <T> T sample(Counter<T> c) { return sample(c, null); } /** * Returns a counter where each element corresponds to the normalized count of * the corresponding element in c raised to the given power. 
*/ public static <E> Counter<E> powNormalized(Counter<E> c, double temp) { Counter<E> d = c.getFactory().create(); double total = c.totalCount(); for (E e : c.keySet()) { d.setCount(e, Math.pow(c.getCount(e) / total, temp)); } return d; } public static <T> Counter<T> pow(Counter<T> c, double temp) { Counter<T> d = c.getFactory().create(); for (T t : c.keySet()) { d.setCount(t, Math.pow(c.getCount(t), temp)); } return d; } public static <T> void powInPlace(Counter<T> c, double temp) { for (T t : c.keySet()) { c.setCount(t, Math.pow(c.getCount(t), temp)); } } public static <T> Counter<T> exp(Counter<T> c) { Counter<T> d = c.getFactory().create(); for (T t : c.keySet()) { d.setCount(t, Math.exp(c.getCount(t))); } return d; } public static <T> void expInPlace(Counter<T> c) { for (T t : c.keySet()) { c.setCount(t, Math.exp(c.getCount(t))); } } public static <T> Counter<T> diff(Counter<T> goldFeatures, Counter<T> guessedFeatures) { Counter<T> result = goldFeatures.getFactory().create(); for (T key : Sets.union(goldFeatures.keySet(), guessedFeatures.keySet())) { result.setCount(key, goldFeatures.getCount(key) - guessedFeatures.getCount(key)); } retainNonZeros(result); return result; } /** * Default equality comparison for two counters potentially backed by * alternative implementations. */ public static <E> boolean equals(Counter<E> o1, Counter<E> o2) { return equals(o1, o2, 0.0); } /** * Equality comparison between two counters, allowing for a tolerance fudge factor. */ public static <E> boolean equals(Counter<E> o1, Counter<E> o2, double tolerance) { if (o1 == o2) { return true; } if (Math.abs(o1.totalCount() - o2.totalCount()) > tolerance) { return false; } if (!o1.keySet().equals(o2.keySet())) { return false; } for (E key : o1.keySet()) { if (Math.abs(o1.getCount(key) - o2.getCount(key)) > tolerance) { return false; } } return true; } /** * Returns unmodifiable view of the counter. changes to the underlying Counter * are written through to this Counter. 
* * @param counter * The counter * @return unmodifiable view of the counter */ public static <T> Counter<T> unmodifiableCounter(final Counter<T> counter) { return new AbstractCounter<T>() { public void clear() { throw new UnsupportedOperationException(); } public boolean containsKey(T key) { return counter.containsKey(key); } public double getCount(Object key) { return counter.getCount(key); } public Factory<Counter<T>> getFactory() { return counter.getFactory(); } public double remove(T key) { throw new UnsupportedOperationException(); } public void setCount(T key, double value) { throw new UnsupportedOperationException(); } @Override public double incrementCount(T key, double value) { throw new UnsupportedOperationException(); } @Override public double incrementCount(T key) { throw new UnsupportedOperationException(); } @Override public double logIncrementCount(T key, double value) { throw new UnsupportedOperationException(); } public int size() { return counter.size(); } public double totalCount() { return counter.totalCount(); } public Collection<Double> values() { return counter.values(); } public Set<T> keySet() { return Collections.unmodifiableSet(counter.keySet()); } public Set<Entry<T, Double>> entrySet() { return Collections.unmodifiableSet(new AbstractSet<Map.Entry<T, Double>>() { @Override public Iterator<Entry<T, Double>> iterator() { return new Iterator<Entry<T, Double>>() { final Iterator<Entry<T, Double>> inner = counter.entrySet().iterator(); public boolean hasNext() { return inner.hasNext(); } public Entry<T, Double> next() { return new Map.Entry<T, Double>() { final Entry<T, Double> e = inner.next(); @Override public T getKey() { return e.getKey(); } @Override @SuppressWarnings( { "UnnecessaryBoxing", "UnnecessaryUnboxing" }) public Double getValue() { return Double.valueOf(e.getValue().doubleValue()); } @Override public Double setValue(Double value) { throw new UnsupportedOperationException(); } }; } @Override public void remove() { throw new 
UnsupportedOperationException(); } }; } @Override public int size() { return counter.size(); } }); } @Override public void setDefaultReturnValue(double rv) { throw new UnsupportedOperationException(); } @Override public double defaultReturnValue() { return counter.defaultReturnValue(); } /** * {@inheritDoc} */ public void prettyLog(RedwoodChannels channels, String description) { PrettyLogger.log(channels, description, asMap(this)); } }; } // end unmodifiableCounter() /** * Returns a counter whose keys are the elements in this priority queue, and * whose counts are the priorities in this queue. In the event there are * multiple instances of the same element in the queue, the counter's count * will be the sum of the instances' priorities. * */ public static <E> Counter<E> asCounter(FixedPrioritiesPriorityQueue<E> p) { FixedPrioritiesPriorityQueue<E> pq = p.clone(); ClassicCounter<E> counter = new ClassicCounter<>(); while (pq.hasNext()) { double priority = pq.getPriority(); E element = pq.next(); counter.incrementCount(element, priority); } return counter; } /** * Returns a counter view of the given map. Infers the numeric type of the * values from the first element in map.values(). */ @SuppressWarnings("unchecked") public static <E, N extends Number> Counter<E> fromMap(final Map<E, N> map) { if (map.isEmpty()) { throw new IllegalArgumentException("Map must have at least one element" + " to infer numeric type; add an element first or use e.g." + " fromMap(map, Integer.class)"); } return fromMap(map, (Class) map.values().iterator().next().getClass()); } /** * Returns a counter view of the given map. The type parameter is the type of * the values in the map, which because of Java's generics type erasure, can't * be discovered by reflection if the map is currently empty. 
*
   * <p>
   * Reads go straight to the map. Writes made through the returned counter
   * are converted to the given numeric type before being stored. The total
   * count is computed once from the map's current values and afterwards only
   * adjusted by operations performed through the returned counter, so changes
   * made directly to the map are not reflected in {@code totalCount()}.
   *
   * @param map The backing map
   * @param type The numeric class of the map's values; Double, Integer, Float,
   *          Long and Short are recognized when writing
   * @return A counter view backed by the map
   */
  public static <E, N extends Number> Counter<E> fromMap(final Map<E, N> map, final Class<N> type) {
    // get our initial total
    double initialTotal = 0.0;
    for (Map.Entry<E, N> entry : map.entrySet()) {
      initialTotal += entry.getValue().doubleValue();
    }

    // and pass it in to the returned inner class with a final variable
    final double initialTotalFinal = initialTotal;

    return new AbstractCounter<E>() {
      // running sum of the values, maintained by the mutators below
      double total = initialTotalFinal;
      // value reported by getCount() and remove() for keys absent from the map
      double defRV = 0.0;

      @Override
      public void clear() {
        map.clear();
        total = 0.0;
      }

      @Override
      public boolean containsKey(E key) {
        return map.containsKey(key);
      }

      @Override
      public void setDefaultReturnValue(double rv) {
        defRV = rv;
      }

      @Override
      public double defaultReturnValue() {
        return defRV;
      }

      // Equality is delegated to Counters.equals, so any Counter
      // implementation can compare equal to this view.
      @Override
      @SuppressWarnings("unchecked")
      public boolean equals(Object o) {
        if (this == o) {
          return true;
        } else if (!(o instanceof Counter)) {
          return false;
        } else {
          return Counters.equals(this, (Counter<E>) o);
        }
      }

      // The hash code is that of the backing map.
      @Override
      public int hashCode() {
        return map.hashCode();
      }

      // Entry-set view exposing the map's values as Doubles. Writes through
      // Entry.setValue and Iterator.remove reach the map and adjust total.
      public Set<Entry<E, Double>> entrySet() {
        return new AbstractSet<Entry<E, Double>>() {
          Set<Entry<E, N>> entries = map.entrySet();

          @Override
          public Iterator<Entry<E, Double>> iterator() {
            return new Iterator<Entry<E, Double>>() {
              Iterator<Entry<E, N>> it = entries.iterator();
              // entry most recently returned by next(); remove() reads its value
              Entry<E, N> lastEntry; // = null;

              public boolean hasNext() {
                return it.hasNext();
              }

              public Entry<E, Double> next() {
                final Entry<E, N> entry = it.next();
                lastEntry = entry;

                return new Entry<E, Double>() {
                  public E getKey() {
                    return entry.getKey();
                  }

                  public Double getValue() {
                    return entry.getValue().doubleValue();
                  }

                  // Stores the value converted to the map's numeric type and
                  // returns the previous value as a Double.
                  public Double setValue(Double value) {
                    final double lastValue = entry.getValue().doubleValue();
                    double rv;

                    if (type == Double.class) {
                      rv = ErasureUtils.<Entry<E, Double>> uncheckedCast(entry).setValue(value);
                    } else if (type == Integer.class) {
                      rv = ErasureUtils.<Entry<E, Integer>> uncheckedCast(entry).setValue(value.intValue());
                    } else if (type == Float.class) {
                      rv = ErasureUtils.<Entry<E, Float>> uncheckedCast(entry).setValue(value.floatValue());
                    } else if (type == Long.class) {
                      rv = ErasureUtils.<Entry<E, Long>> uncheckedCast(entry).setValue(value.longValue());
                    } else if (type == Short.class) {
                      rv = ErasureUtils.<Entry<E, Short>> uncheckedCast(entry).setValue(value.shortValue());
                    } else {
                      throw new RuntimeException("Unrecognized numeric type in wrapped counter");
                    }

                    // need to call getValue().doubleValue() to make sure
                    // we keep the same precision as the underlying map
                    total += entry.getValue().doubleValue() - lastValue;

                    return rv;
                  }
                };
              }

              // The value is subtracted from total before the entry is
              // removed from the map.
              // NOTE(review): lastEntry is still null if remove() is called
              // before next(), so this would throw a NullPointerException.
              public void remove() {
                total -= lastEntry.getValue().doubleValue();
                it.remove();
              }
            };
          }

          @Override
          public int size() {
            return map.size();
          }
        };
      }

      // Returns the mapped value as a double, or defRV when the map has no
      // value for the key.
      public double getCount(Object key) {
        final Number value = map.get(key);
        return value != null ? value.doubleValue() : defRV;
      }

      public Factory<Counter<E>> getFactory() {
        return new Factory<Counter<E>>() {
          private static final long serialVersionUID = -4063129407369590522L;

          public Counter<E> create() {
            // return a HashMap backed by the same numeric type to
            // keep the precision of the returned counter consistent with
            // this one's precision
            return fromMap(Generics.<E, N>newHashMap(), type);
          }
        };
      }

      // Key-set view whose iterator does not support removal.
      public Set<E> keySet() {
        return new AbstractSet<E>() {
          @Override
          public Iterator<E> iterator() {
            return new Iterator<E>() {
              Iterator<E> it = map.keySet().iterator();

              public boolean hasNext() {
                return it.hasNext();
              }

              public E next() {
                return it.next();
              }

              public void remove() {
                throw new UnsupportedOperationException("Cannot remove from key set");
              }
            };
          }

          @Override
          public int size() {
            return map.size();
          }
        };
      }

      // Removes the key from the map and returns its former value, or defRV
      // if the map held no value for it.
      public double remove(E key) {
        final Number removed = map.remove(key);
        if (removed != null) {
          final double rv = removed.doubleValue();
          total -= rv;
          return rv;
        }
        return defRV;
      }

      // Stores the value cast to the map's numeric type and adjusts total by
      // the difference between the stored value and the previous one.
      public void setCount(E key, double value) {
        // previous value for the key, or null if there was none
        final Double lastValue;
        // the value as actually stored, i.e. after the cast
        double newValue;

        if (type == Double.class) {
          lastValue = ErasureUtils.<Map<E, Double>> uncheckedCast(map).put(key, value);
          newValue = value;
        } else if (type == Integer.class) {
          final Integer last = ErasureUtils.<Map<E, Integer>> uncheckedCast(map).put(key, (int) value);
          lastValue = last != null ? last.doubleValue() : null;
          newValue = ((int) value);
        } else if (type == Float.class) {
          final Float last = ErasureUtils.<Map<E, Float>> uncheckedCast(map).put(key, (float) value);
          lastValue = last != null ? last.doubleValue() : null;
          newValue = ((float) value);
        } else if (type == Long.class) {
          final Long last = ErasureUtils.<Map<E, Long>> uncheckedCast(map).put(key, (long) value);
          lastValue = last != null ? last.doubleValue() : null;
          newValue = ((long) value);
        } else if (type == Short.class) {
          final Short last = ErasureUtils.<Map<E, Short>> uncheckedCast(map).put(key, (short) value);
          lastValue = last != null ? last.doubleValue() : null;
          newValue = ((short) value);
        } else {
          throw new RuntimeException("Unrecognized numeric type in wrapped counter");
        }

        // need to use newValue instead of value to make sure we
        // keep same precision as underlying map.
        total += newValue - (lastValue != null ? lastValue : 0);
      }

      public int size() {
        return map.size();
      }

      public double totalCount() {
        return total;
      }

      // Values view exposing the map's values as Doubles; its iterator does
      // not support removal.
      public Collection<Double> values() {
        return new AbstractCollection<Double>() {
          @Override
          public Iterator<Double> iterator() {
            return new Iterator<Double>() {
              final Iterator<N> it = map.values().iterator();

              public boolean hasNext() {
                return it.hasNext();
              }

              public Double next() {
                return it.next().doubleValue();
              }

              public void remove() {
                throw new UnsupportedOperationException("Cannot remove from values collection");
              }
            };
          }

          @Override
          public int size() {
            return map.size();
          }
        };
      }

      /**
       * {@inheritDoc}
       */
      public void prettyLog(RedwoodChannels channels, String description) {
        PrettyLogger.log(channels, description, map);
      }
    };
  } // end fromMap()

  /**
   * Returns a map view of the given counter.
*/
  public static <E> Map<E, Double> asMap(final Counter<E> counter) {
    // Every operation is forwarded to the counter, so the map is a live view.
    return new AbstractMap<E, Double>() {

      @Override
      public Set<Entry<E, Double>> entrySet() {
        return counter.entrySet();
      }

      @Override
      public Set<E> keySet() {
        return counter.keySet();
      }

      @Override
      public int size() {
        return counter.size();
      }

      @Override
      @SuppressWarnings("unchecked")
      public boolean containsKey(Object key) {
        E typedKey = (E) key;
        return counter.containsKey(typedKey);
      }

      @Override
      @SuppressWarnings("unchecked")
      public Double get(Object key) {
        E typedKey = (E) key;
        return counter.getCount(typedKey);
      }

      @Override
      public Double put(E key, Double value) {
        // Report the count held before the write, as Map.put returns the old value.
        double previous = counter.getCount(key);
        counter.setCount(key, value);
        return previous;
      }

      @Override
      @SuppressWarnings("unchecked")
      public Double remove(Object key) {
        E typedKey = (E) key;
        return counter.remove(typedKey);
      }
    };
  }

  /**
   * Check if this counter is a uniform distribution.
   * That is, it should sum to 1.0, and every value should be equal to every other value.
   * @param distribution The distribution to check.
   * @param tolerance The tolerance for floating point error, in both the equality and total count checks.
   * @param <E> The type of the counter.
   * @return True if this counter is the uniform distribution over its domain.
   */
  public static <E> boolean isUniformDistribution(Counter<E> distribution, double tolerance) {
    double reference = Double.NaN; // NaN means "no reference value chosen yet"
    double sum = 0.0;
    for (double current : distribution.values()) {
      if (Double.isNaN(reference)) {
        reference = current;
      }
      double deviation = Math.abs(current - reference);
      if (deviation > tolerance) {
        return false;
      }
      sum += current;
    }
    double distanceFromOne = Math.abs(sum - 1.0);
    return distanceFromOne < tolerance;
  }

  /**
   * Default comparator for breaking ties in argmin and argmax.
   * //TODO: What type should this be?
   * // Unused, so who cares?
* private static final Comparator<Object> hashCodeComparator = * new Comparator<Object>() { * public int compare(Object o1, Object o2) { * return o1.hashCode() - o2.hashCode(); * } * * public boolean equals(Comparator comparator) { * return (comparator == this); * } * }; */ /** * Comparator that uses natural ordering. Returns 0 if o1 is not Comparable. */ static class NaturalComparator<E> implements Comparator<E> { public NaturalComparator() { } @Override public String toString() { return "NaturalComparator"; } @SuppressWarnings("unchecked") public int compare(E o1, E o2) { if (o1 instanceof Comparable) { return (((Comparable<E>) o1).compareTo(o2)); } return 0; // soft-fail } } /** * * @param <E> * @param originalCounter * @return a copy of the original counter */ public static <E> Counter<E> getCopy(Counter<E> originalCounter) { Counter<E> copyCounter = new ClassicCounter<>(); copyCounter.addAll(originalCounter); return copyCounter; } /** * Places the maximum of first and second keys values in the first counter. * @param <E> */ public static <E> void maxInPlace(Counter<E> target, Counter<E> other) { for(E e: CollectionUtils.union(other.keySet(), target.keySet())){ target.setCount(e, Math.max(target.getCount(e), other.getCount(e))); } } /** * Places the minimum of first and second keys values in the first counter. * @param <E> */ public static <E> void minInPlace(Counter<E> target, Counter<E> other){ for(E e: CollectionUtils.union(other.keySet(), target.keySet())){ target.setCount(e, Math.min(target.getCount(e), other.getCount(e))); } } /** * Retains the minimal set of top keys such that their count sum is more than thresholdCount. 
* @param counter * @param thresholdCount */ public static<E> void retainTopMass(Counter<E> counter, double thresholdCount){ PriorityQueue<E> queue = Counters.toPriorityQueue(counter); counter.clear(); double mass = 0; while (mass < thresholdCount && !queue.isEmpty()) { double value = queue.getPriority(); E key = queue.removeFirst(); counter.setCount(key, value); mass += value; } } public static<A,B> void divideInPlace(TwoDimensionalCounter<A, B> counter, double divisor) { for(Entry<A, ClassicCounter<B>> c: counter.entrySet()){ Counters.divideInPlace(c.getValue(), divisor); } counter.recomputeTotal(); } public static<E> double pearsonsCorrelationCoefficient(Counter<E> x, Counter<E> y){ double stddevX = Counters.standardDeviation(x); double stddevY = Counters.standardDeviation(y); double meanX = Counters.mean(x); double meanY = Counters.mean(y); Counter<E> t1 = Counters.add(x, -meanX); Counter<E> t2 = Counters.add(y, -meanY); Counters.divideInPlace(t1, stddevX); Counters.divideInPlace(t2, stddevY); return Counters.dotProduct(t1, t2)/ (double)(x.size() -1); } public static<E> double spearmanRankCorrelation(Counter<E> x, Counter<E> y){ Counter<E> xrank = Counters.toTiedRankCounter(x); Counter<E> yrank = Counters.toTiedRankCounter(y); return Counters.pearsonsCorrelationCoefficient(xrank, yrank); } /** * ensures that counter t has all keys in keys. If the counter does not have the keys, then add the key with count value. 
* Note that it does not change counts that exist in the counter */ public static<E> void ensureKeys(Counter<E> t, Collection<E> keys, double value){ for(E k: keys){ if(!t.containsKey(k)) t.setCount(k, value); } } public static<E> List<E> topKeys(Counter<E> t, int topNum){ List<E> list = new ArrayList<>(); PriorityQueue<E> q = Counters.toPriorityQueue(t); int num = 0; while(!q.isEmpty() && num < topNum){ num++; list.add(q.removeFirst()); } return list; } public static<E> List<Pair<E, Double>> topKeysWithCounts(Counter<E> t, int topNum){ List<Pair<E, Double>> list = new ArrayList<>(); PriorityQueue<E> q = Counters.toPriorityQueue(t); int num = 0; while(!q.isEmpty() && num < topNum){ num++; E k = q.removeFirst(); list.add(new Pair<>(k, t.getCount(k))); } return list; } public static<E> Counter<E> getFCounter(Counter<E> precision, Counter<E> recall, double beta){ Counter<E> fscores = new ClassicCounter<>(); for(E k: precision.keySet()){ fscores.setCount(k, precision.getCount(k)*recall.getCount(k)*(1+beta*beta)/(beta*beta*precision.getCount(k) + recall.getCount(k))); } return fscores; } public static <E> void transformValuesInPlace(Counter<E> counter, Function<Double, Double> func){ for(E key: counter.keySet()){ counter.setCount(key, func.apply(counter.getCount(key))); } } public static<E> Counter<E> getCounts(Counter<E> c, Collection<E> keys){ Counter<E> newcounter = new ClassicCounter<>(); for(E k : keys) newcounter.setCount(k, c.getCount(k)); return newcounter; } public static<E> void retainKeys(Counter<E> counter, Function<E, Boolean> retainFunction) { Set<E> remove = new HashSet<>(); for(Entry<E, Double> en: counter.entrySet()){ if(!retainFunction.apply(en.getKey())){ remove.add(en.getKey()); } } Counters.removeKeys(counter, remove); } public static<E, E2> Counter<E> flatten(Map<E2, Counter<E>> hier){ Counter<E> flat = new ClassicCounter<>(); for(Entry<E2, Counter<E>> en: hier.entrySet()){ flat.addAll(en.getValue()); } return flat; } /** * Returns true if the given 
counter contains only finite, non-NaN values.
   * @param counts The counter to validate.
   * @param <E> The parameterized type of the counter.
   * @return True if the counter is finite and not NaN on every value.
   */
  public static <E> boolean isFinite(Counter<E> counts) {
    for (double count : counts.values()) {
      // Double.isFinite is false exactly for NaN and the two infinities.
      if (!Double.isFinite(count)) {
        return false;
      }
    }
    return true;
  }

}