package edu.stanford.nlp.stats;

import java.util.Set;

import edu.stanford.nlp.util.Generics;

/**
 * Static methods for operating on {@link Distribution}s.
 *
 * In general, if a method is operating on a pair of Distribution objects, we imagine that the
 * set of possible keys for each Distribution is the same.
 * Therefore we require that d1.numberOfKeys == d2.numberOfKeys and that the number of keys in the
 * union of the two key sets is at most numKeys.
 *
 * @author Jeff Michels (jmichels@stanford.edu)
 */
public class Distributions {

  private Distributions() {
  }

  protected static <K> Set<K> getSetOfAllKeys(Distribution<K> d1, Distribution<K> d2) {
    if (d1.getNumberOfKeys() != d2.getNumberOfKeys()) {
      throw new RuntimeException("Tried to compare two Distribution<K> objects but d1.numberOfKeys != d2.numberOfKeys");
    }
    Set<K> allKeys = Generics.newHashSet(d1.getCounter().keySet());
    allKeys.addAll(d2.getCounter().keySet());
    if (allKeys.size() > d1.getNumberOfKeys()) {
      throw new RuntimeException("Tried to compare two Distribution<K> objects but the union of their key sets has more than numberOfKeys elements");
    }
    return allKeys;
  }

  /**
   * Returns a double between 0 and 1 representing the overlap of d1 and d2.
   * Equals 0 if there is no overlap, and equals 1 iff d1 == d2.
   */
  public static <K> double overlap(Distribution<K> d1, Distribution<K> d2) {
    Set<K> allKeys = getSetOfAllKeys(d1, d2);
    double result = 0.0;
    double remainingMass1 = 1.0;
    double remainingMass2 = 1.0;
    for (K key : allKeys) {
      double p1 = d1.probabilityOf(key);
      double p2 = d2.probabilityOf(key);
      remainingMass1 -= p1;
      remainingMass2 -= p2;
      result += Math.min(p1, p2);
    }
    // the mass not assigned to any explicit key overlaps by the smaller remainder
    result += Math.min(remainingMass1, remainingMass2);
    return result;
  }

  /**
   * Returns a new Distribution with counts averaged from the two given Distributions.
   * The averaged Distribution will contain the union of keys in both
   * source Distributions, and each value will be the weighted average of the two source
   * probabilities for that key; a missing key in one Distribution is treated as if it
   * had the probability returned by that Distribution's probabilityOf() method.
   *
   * @return A new distribution whose values are the weighted mean of the respective
   *         probabilities in the given distributions, with the remaining probability
   *         mass adjusted accordingly.
   */
  public static <K> Distribution<K> weightedAverage(Distribution<K> d1, double w1, Distribution<K> d2) {
    double w2 = 1.0 - w1;
    Set<K> allKeys = getSetOfAllKeys(d1, d2);
    int numKeys = d1.getNumberOfKeys();
    Counter<K> c = new ClassicCounter<>();
    for (K key : allKeys) {
      double newProbability = d1.probabilityOf(key) * w1 + d2.probabilityOf(key) * w2;
      c.setCount(key, newProbability);
    }
    return Distribution.getDistributionFromPartiallySpecifiedCounter(c, numKeys);
  }

  public static <K> Distribution<K> average(Distribution<K> d1, Distribution<K> d2) {
    return weightedAverage(d1, 0.5, d2);
  }
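  /*
   * Usage sketch (illustrative, not part of the original API): builds two small
   * Distributions over the same key space and compares them with the methods in
   * this class. The keys, the probability values, and the key-space size of 4
   * are made-up example data; the only factory used is
   * getDistributionFromPartiallySpecifiedCounter, whose counter values are
   * assumed to be probabilities summing to at most 1, mirroring its usage in
   * weightedAverage above.
   */
  public static void main(String[] args) {
    Counter<String> c1 = new ClassicCounter<>();
    c1.setCount("a", 0.6);
    c1.setCount("b", 0.2);
    Counter<String> c2 = new ClassicCounter<>();
    c2.setCount("a", 0.2);
    c2.setCount("c", 0.6);
    // Declare both distributions over a 4-key space; each spreads its leftover
    // 0.2 of probability mass uniformly over the keys its counter has not seen.
    Distribution<String> d1 = Distribution.getDistributionFromPartiallySpecifiedCounter(c1, 4);
    Distribution<String> d2 = Distribution.getDistributionFromPartiallySpecifiedCounter(c2, 4);
    System.out.println("overlap(d1,d2) = " + overlap(d1, d2));            // in [0, 1]
    System.out.println("KL(d1||d2) = " + klDivergence(d1, d2));           // asymmetric, in bits
    System.out.println("JS(d1,d2) = " + jensenShannonDivergence(d1, d2)); // symmetric
  }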
  /**
   * Calculates the KL divergence between the two distributions.
   * That is, it calculates KL(from || to).
   * In other words, it measures how well {@code from} can be represented by {@code to}.
   * If some key has nonzero probability in {@code from} but zero probability in
   * {@code to}, returns positive infinity.
   *
   * @return The KL divergence between the distributions
   */
  public static <K> double klDivergence(Distribution<K> from, Distribution<K> to) {
    Set<K> allKeys = getSetOfAllKeys(from, to);
    int numKeysRemaining = from.getNumberOfKeys();
    double result = 0.0;
    double assignedMass1 = 0.0;
    double assignedMass2 = 0.0;
    double log2 = Math.log(2.0);
    double p1, p2;
    double epsilon = 1e-10;
    for (K key : allKeys) {
      p1 = from.probabilityOf(key);
      p2 = to.probabilityOf(key);
      numKeysRemaining--;
      assignedMass1 += p1;
      assignedMass2 += p2;
      if (p1 < epsilon) {
        continue;
      }
      double logFract = Math.log(p1 / p2);
      if (logFract == Double.POSITIVE_INFINITY) {
        System.out.println("Distributions.klDivergence returning +inf: p1=" + p1 + ", p2=" + p2);
        System.out.flush();
        return Double.POSITIVE_INFINITY; // can't recover
      }
      result += p1 * (logFract / log2); // express it in log base 2
    }
    // account for the keys not seen in either counter, over which each
    // distribution spreads its remaining probability mass uniformly
    if (numKeysRemaining != 0) {
      p1 = (1.0 - assignedMass1) / numKeysRemaining;
      if (p1 > epsilon) {
        p2 = (1.0 - assignedMass2) / numKeysRemaining;
        double logFract = Math.log(p1 / p2);
        if (logFract == Double.POSITIVE_INFINITY) {
          System.out.println("Distributions.klDivergence (remaining mass) returning +inf: p1=" + p1 + ", p2=" + p2);
          System.out.flush();
          return Double.POSITIVE_INFINITY; // can't recover
        }
        result += numKeysRemaining * p1 * (logFract / log2); // express it in log base 2
      }
    }
    return result;
  }

  /**
   * Calculates the Jensen-Shannon divergence between the two distributions.
   * That is, it calculates 1/2 [KL(d1 || avg(d1,d2)) + KL(d2 || avg(d1,d2))].
   *
   * @return The Jensen-Shannon divergence between the distributions
   */
  public static <K> double jensenShannonDivergence(Distribution<K> d1, Distribution<K> d2) {
    Distribution<K> average = average(d1, d2);
    double kl1 = klDivergence(d1, average);
    double kl2 = klDivergence(d2, average);
    return (kl1 + kl2) / 2.0;
  }

  /**
   * Calculates the skew divergence between the two distributions.
   * That is, it calculates KL(d1 || (d2*skew + d1*(1-skew))).
   * In other words, how well can d1 be represented by a "smoothed" d2.
   *
   * @return The skew divergence between the distributions
   */
  public static <K> double skewDivergence(Distribution<K> d1, Distribution<K> d2, double skew) {
    Distribution<K> average = weightedAverage(d2, skew, d1);
    return klDivergence(d1, average);
  }

  /**
   * Calculates the information radius (closely related to the Jensen-Shannon
   * divergence) between the two Distributions. This measure is defined as:
   * <blockquote> iRad(p,q) = D(p||(p+q)/2) + D(q||(p+q)/2) </blockquote>
   * where p is one Distribution, q is the other Distribution, and D(p||q) is the
   * KL divergence between p and q. Note that iRad(p,q) = iRad(q,p), and that,
   * as computed here, iRad(p,q) = 2 * jensenShannonDivergence(p,q).
   *
   * @return The information radius between the distributions
   */
  public static <K> double informationRadius(Distribution<K> d1, Distribution<K> d2) {
    Distribution<K> avg = average(d1, d2); // (p+q)/2
    return klDivergence(d1, avg) + klDivergence(d2, avg);
  }

}
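/*
 * Worked example (illustrative, appended as a comment): for p uniform over two
 * keys, p = [0.5, 0.5], and q = [0.75, 0.25],
 *   KL(p||q) = 0.5*log2(0.5/0.75) + 0.5*log2(0.5/0.25) ~= 0.2075 bits,
 *   KL(q||p) = 0.75*log2(0.75/0.5) + 0.25*log2(0.25/0.5) ~= 0.1887 bits,
 * showing the asymmetry that jensenShannonDivergence symmetrizes. By the
 * definitions above, informationRadius(p,q) = 2 * jensenShannonDivergence(p,q).
 */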