package edu.cmu.graphchi.walks.distribution; import edu.cmu.graphchi.walks.distributions.DiscreteDistribution; import edu.cmu.graphchi.util.IdCount; import org.junit.Test; import java.util.*; import static org.junit.Assert.*; import static org.junit.Assert.assertEquals; /** * @author Aapo Kyrola, akyrola@twitter.com, akyrola@cs.cmu.edu */ public class TestDiscreteDistribution { @Test public void testTrivialConstruction() { DiscreteDistribution empty = new DiscreteDistribution(new int[0]); assertEquals(0, empty.getCount(0)); assertEquals(0, empty.getCount(1)); DiscreteDistribution singleton = new DiscreteDistribution(new int[] {876}); assertEquals(1, singleton.getCount(876)); assertEquals(0, singleton.getCount(875)); assertEquals(0, singleton.getCount(877)); } @Test public void testConstruction() { DiscreteDistribution d1 = new DiscreteDistribution(new int[] {1,1,1,8,8,8,9,22,22,22,22,22}); assertEquals(3, d1.getCount(1)); assertEquals(3, d1.getCount(8)); assertEquals(1, d1.getCount(9)); assertEquals(5, d1.getCount(22)); assertEquals(0, d1.getCount(0)); assertEquals(0, d1.getCount(10)); assertEquals(0, d1.getCount(10324324)); } @Test public void testMerge() { DiscreteDistribution d1 = new DiscreteDistribution(new int[] {1,1,1,8,8,8,9,22,22,22,22,22}); DiscreteDistribution d2 = new DiscreteDistribution(new int[] {1,1,1,8,8,8,9,22,22,22,22,22}); DiscreteDistribution merged = DiscreteDistribution.merge(d1, d2); assertEquals(3 * 2, merged.getCount(1)); assertEquals(3 * 2, merged.getCount(8)); assertEquals(2, merged.getCount(9)); assertEquals(5 * 2, merged.getCount(22)); assertEquals(0, merged.getCount(0)); assertEquals(0, merged.getCount(10)); assertEquals(0, merged.getCount(10324324)); DiscreteDistribution d3 = new DiscreteDistribution(new int[] {1,7,8,1000,1000,1000,2000,2000,30000}); DiscreteDistribution merged2 = DiscreteDistribution.merge(d1, d3); assertEquals(4, merged2.getCount(1)); assertEquals(1, merged2.getCount(7)); assertEquals(4, merged2.getCount(8)); assertEquals(1, merged2.getCount(9)); assertEquals(5, merged2.getCount(22)); assertEquals(3, merged2.getCount(1000)); assertEquals(2, merged2.getCount(2000)); assertEquals(1, merged2.getCount(30000)); assertEquals(8, merged2.size()); DiscreteDistribution empty = new DiscreteDistribution(); DiscreteDistribution mergedWithEmpty = DiscreteDistribution.merge(empty, d1); DiscreteDistribution mergedWithEmpty2 = DiscreteDistribution.merge(d2, empty); assertTrue(d1 == mergedWithEmpty); assertTrue(d2 == mergedWithEmpty2); } private int[] toIntArray(ArrayList<Integer> arr) { int[] a = new int[arr.size()]; for(int i=0; i<arr.size(); i++) { a[i] = arr.get(i); } return a; } @Test public void testBigMerge() { Random r = new Random(260379); TreeMap<Integer, Integer> leftSet = new TreeMap<Integer, Integer>(); TreeMap<Integer, Integer> rightSet = new TreeMap<Integer, Integer>(); // There must be collisions and also some omissions for(int i=0; i < 4001; i++) { leftSet.put(r.nextInt(4000), r.nextInt(1000)); rightSet.put(r.nextInt(4000), r.nextInt(1000)); } // Compose ArrayList<Integer> leftArray = new ArrayList<Integer>(4000); for(Map.Entry<Integer ,Integer> e : leftSet.entrySet()) { for(int j=0; j<e.getValue(); j++) leftArray.add(e.getKey()); } ArrayList<Integer> rightArray = new ArrayList<Integer>(4000); for(Map.Entry<Integer ,Integer> e : rightSet.entrySet()) { for(int j=0; j<e.getValue(); j++) rightArray.add(e.getKey()); } DiscreteDistribution leftDist = new DiscreteDistribution(toIntArray(leftArray)); DiscreteDistribution rightDist = new DiscreteDistribution(toIntArray(rightArray)); DiscreteDistribution mergedDist1 = DiscreteDistribution.merge(leftDist, rightDist); DiscreteDistribution mergedDist2 = DiscreteDistribution.merge(rightDist, leftDist); for(int i=0; i < 5000; i++) { int lc = (leftSet.containsKey(i) ? leftSet.get(i) : 0); int rc = (rightSet.containsKey(i) ? rightSet.get(i) : 0); assertEquals(lc, leftDist.getCount(i)); assertEquals(rc, rightDist.getCount(i)); assertEquals(lc + rc, mergedDist1.getCount(i)); assertEquals(lc + rc, mergedDist2.getCount(i)); } } private void insertMultiple(ArrayList<Integer> arr, int val, int n) { for(int i=0; i<n; i++) arr.add(val); } @Test public void testTop() { Random r = new Random(260379); ArrayList<Integer> workArr = new ArrayList<Integer>(); TreeMap<Integer, Integer> countToId = new TreeMap<Integer, Integer>(new Comparator<Integer>() { public int compare(Integer integer, Integer integer1) { return -integer.compareTo(integer1); } }); for(int i=1; i < 200; i+=2) { int n; do { n = r.nextInt(10000); } while (countToId.containsKey(n)); // Unique count keys countToId.put(n, i); insertMultiple(workArr, i, n); } DiscreteDistribution dist = new DiscreteDistribution(toIntArray(workArr)); IdCount[] top = dist.getTop(10); int j = 0; for(Map.Entry <Integer, Integer> e : countToId.entrySet()) { IdCount topEntryJ = top[j]; assertEquals((int)e.getValue(), topEntryJ.id); assertEquals((int)e.getKey(), topEntryJ.count); j++; if (top.length <= j) { assertEquals(10, j); break; } } } @Test public void testTopWithOnes() { DiscreteDistribution d1 = new DiscreteDistribution(new int[] {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,20,21,22,23,200,298}); DiscreteDistribution avoidDist = DiscreteDistribution.createAvoidanceDistribution(new int[]{12, 15, 20}); IdCount[] top = d1.getTop(10); assertEquals(10, top.length); DiscreteDistribution mergedWithAvoid = DiscreteDistribution.merge(d1, avoidDist); IdCount[] top2 = mergedWithAvoid.getTop(10); assertEquals(10, top2.length); } @Test public void testFiltering() { DiscreteDistribution d1 = new DiscreteDistribution(new int[] {1,1,1,8,8,8,9,22,22,22,22,22,333,333,333,333,333,333,333,333}); assertEquals(3, d1.getCount(1)); assertEquals(3, d1.getCount(8)); assertEquals(1, d1.getCount(9)); assertEquals(5, d1.getCount(22)); assertEquals(8, d1.getCount(333)); DiscreteDistribution notReallyFiltered = d1.filteredAndShift(1); assertTrue(notReallyFiltered == d1); DiscreteDistribution filtered = d1.filteredAndShift(4); assertEquals(0, filtered.getCount(1)); assertEquals(0, filtered.getCount(8)); assertEquals(0, filtered.getCount(9)); assertEquals(5 - 3, filtered.getCount(22)); assertEquals(8 - 3, filtered.getCount(333)); assertEquals(2, filtered.size()); // Check that avoided ones are not filtered DiscreteDistribution filteredWithAnAvoid = DiscreteDistribution.merge(d1, DiscreteDistribution.createAvoidanceDistribution(new int[]{99, 108})).filteredAndShift(4); assertEquals(0, filteredWithAnAvoid.getCount(1)); assertEquals(0, filteredWithAnAvoid.getCount(8)); assertEquals(0, filteredWithAnAvoid.getCount(9)); assertEquals(5 - 3, filteredWithAnAvoid.getCount(22)); assertEquals(8 - 3, filteredWithAnAvoid.getCount(333)); assertEquals(-1, filteredWithAnAvoid.getCount(99)); assertEquals(-1, filteredWithAnAvoid.getCount(108)); DiscreteDistribution filteredAll = d1.filteredAndShift(100); DiscreteDistribution filteredAll2 = filtered.filteredAndShift(100); assertEquals(0, filteredAll.getCount(1)); assertEquals(0, filteredAll.getCount(8)); assertEquals(0, filteredAll.getCount(9)); assertEquals(0, filteredAll.getCount(22)); assertEquals(0, filteredAll.getCount(333)); assertEquals(0, filteredAll2.getCount(1)); assertEquals(0, filteredAll2.getCount(8)); assertEquals(0, filteredAll2.getCount(9)); assertEquals(0, filteredAll2.getCount(22)); assertEquals(0, filteredAll2.getCount(333)); assertEquals(0, filteredAll.size()); assertEquals(0, filteredAll2.size()); } @Test public void testAvoidance() { Random r = new Random(260379); /* First create some data */ ArrayList<Integer> workArr = new ArrayList<Integer>(); TreeMap<Integer, Integer> countToId = new TreeMap<Integer, Integer>(new Comparator<Integer>() { public int compare(Integer integer, Integer integer1) { return -integer.compareTo(integer1); } }); for(int i=1; i < 200; i++) { int n; do { n = r.nextInt(10000); } while (countToId.containsKey(n)); // Unique count keys countToId.put(n, i); insertMultiple(workArr, i, n); } DiscreteDistribution dist = new DiscreteDistribution(toIntArray(workArr)); /* Then insert some edges to avoid */ int[] avoids = new int[] {0, 2, 4, 32,33, 66, 67,68, 99, 102, 184}; DiscreteDistribution avoidDistr = DiscreteDistribution.createAvoidanceDistribution(avoids); // Test the merge works both ways DiscreteDistribution mergedL = DiscreteDistribution.merge(dist, avoidDistr); DiscreteDistribution mergedR = DiscreteDistribution.merge(avoidDistr, dist); for(int a : avoids) { assertEquals(-1, avoidDistr.getCount(a)); assertEquals(-1, mergedL.getCount(a)); assertEquals(-1, mergedR.getCount(a)); } IdCount[] top = dist.getTop(10); int j = 0; HashSet<Integer> avoidSet = new HashSet<Integer>(); for(int a : avoids) avoidSet.add(a); for(Map.Entry <Integer, Integer> e : countToId.entrySet()) { IdCount topEntryJ = top[j]; if (!avoidSet.contains(e.getKey())) { assertEquals((int)e.getValue(), topEntryJ.id); assertEquals((int)e.getKey(), topEntryJ.count); j++; if (top.length <= j) { assertEquals(10, j); break; } } } } }