package edu.stanford.nlp.stats;
import java.io.*;
import java.util.*;
import edu.stanford.nlp.util.Pair;
import org.junit.Assert;
import junit.framework.TestCase;
/**
 * Unit tests for the static helper methods of {@link Counters}.
 *
 * <p>Most tests work on the small fixture counters built in {@link #setUp()}:
 * <ul>
 *   <li>{@code c1} = {p:1, q:2, r:3, s:4}</li>
 *   <li>{@code c2} = {p:5, q:6, r:7, t:8}</li>
 *   <li>{@code c8} = {r:2, z:4}</li>
 *   <li>{@code c9} = {z:4}</li>
 * </ul>
 * JUnit 3 runs {@code setUp()} before and {@code tearDown()} after every test method,
 * so individual tests never need to call {@code setUp()} themselves.
 */
public class CountersTest extends TestCase {

  private Counter<String> c1;
  private Counter<String> c2;
  private Counter<String> c8;
  private Counter<String> c9;

  /** Default locale that was active before {@link #setUp()} forced {@link Locale#US}. */
  private Locale savedLocale;

  /** Absolute tolerance used for floating point comparisons that are not exact. */
  private static final double TOLERANCE = 0.001;

  /**
   * Builds the fixture counters and pins the default locale to {@link Locale#US} so that
   * the format-string based tests (e.g. {@code "%.1f"}) produce a '.' decimal separator.
   */
  @Override
  protected void setUp() {
    // Remember the locale only once so that a repeated setUp() call cannot overwrite
    // the value that tearDown() has to restore.
    if (savedLocale == null) {
      savedLocale = Locale.getDefault();
    }
    Locale.setDefault(Locale.US);

    c1 = new ClassicCounter<>();
    c1.setCount("p", 1.0);
    c1.setCount("q", 2.0);
    c1.setCount("r", 3.0);
    c1.setCount("s", 4.0);

    c2 = new ClassicCounter<>();
    c2.setCount("p", 5.0);
    c2.setCount("q", 6.0);
    c2.setCount("r", 7.0);
    c2.setCount("t", 8.0);

    c8 = new ClassicCounter<>();
    c8.setCount("r", 2.0);
    c8.setCount("z", 4.0);

    c9 = new ClassicCounter<>();
    c9.setCount("z", 4.0);
  }

  /** Restores the JVM default locale so this class does not leak state into other tests. */
  @Override
  protected void tearDown() {
    Locale.setDefault(savedLocale);
  }

  /** Checks per-key and total counts of {@code Counters.union(c1, c2)}. */
  public void testUnion() {
    Counter<String> c3 = Counters.union(c1, c2);
    assertEquals(6.0, c3.getCount("p"));
    assertEquals(4.0, c3.getCount("s"));
    assertEquals(8.0, c3.getCount("t"));
    assertEquals(36.0, c3.totalCount());
  }

  /** Checks per-key and total counts of {@code Counters.intersection(c1, c2)}. */
  public void testIntersection() {
    Counter<String> c3 = Counters.intersection(c1, c2);
    assertEquals(1.0, c3.getCount("p"));
    assertEquals(2.0, c3.getCount("q"));
    assertEquals(0.0, c3.getCount("s"));
    assertEquals(0.0, c3.getCount("t"));
    assertEquals(6.0, c3.totalCount());
  }

  /** Checks the key-wise product; keys present in only one counter yield 0. */
  public void testProduct() {
    Counter<String> c3 = Counters.product(c1, c2);
    assertEquals(5.0, c3.getCount("p"));
    assertEquals(12.0, c3.getCount("q"));
    assertEquals(21.0, c3.getCount("r"));
    assertEquals(0.0, c3.getCount("s"));
    assertEquals(0.0, c3.getCount("t"));
  }

  /** Checks {@code dotProduct} and {@code optimizedDotProduct}, in both argument orders. */
  public void testDotProduct() {
    assertEquals(38.0, Counters.dotProduct(c1, c2));
    assertEquals(30.0, Counters.dotProduct(c1, c1));

    assertEquals(38.0, Counters.optimizedDotProduct(c1, c2));
    assertEquals(30.0, Counters.optimizedDotProduct(c1, c1));

    // Partially overlapping and non-overlapping key sets, each tried in both orders.
    assertEquals(14.0, Counters.optimizedDotProduct(c2, c8));
    assertEquals(14.0, Counters.optimizedDotProduct(c8, c2));
    assertEquals(0.0, Counters.optimizedDotProduct(c2, c9));
    assertEquals(0.0, Counters.optimizedDotProduct(c9, c2));
  }

  /** Checks that {@code absoluteDifference} gives the same result for both argument orders. */
  public void testAbsoluteDifference() {
    Counter<String> c3 = Counters.absoluteDifference(c1, c2);
    assertEquals(4.0, c3.getCount("p"));
    assertEquals(4.0, c3.getCount("q"));
    assertEquals(4.0, c3.getCount("r"));
    assertEquals(4.0, c3.getCount("s"));
    assertEquals(8.0, c3.getCount("t"));

    Counter<String> c4 = Counters.absoluteDifference(c2, c1);
    assertEquals(4.0, c4.getCount("p"));
    assertEquals(4.0, c4.getCount("q"));
    assertEquals(4.0, c4.getCount("r"));
    assertEquals(4.0, c4.getCount("s"));
    assertEquals(8.0, c4.getCount("t"));
  }

  /**
   * Round-trips {@code c1} through Java serialization and checks the copy equals the original.
   *
   * @throws IOException if writing to or reading from the in-memory streams fails
   * @throws ClassNotFoundException if the serialized class cannot be resolved on read
   */
  public void testSerialization() throws IOException, ClassNotFoundException {
    ByteArrayOutputStream bout = new ByteArrayOutputStream();
    // Close the object stream before taking the bytes so everything buffered is flushed.
    try (ObjectOutputStream oout = new ObjectOutputStream(bout)) {
      oout.writeObject(c1);
    }

    try (ObjectInputStream oin =
             new ObjectInputStream(new ByteArrayInputStream(bout.toByteArray()))) {
      // The stream contains exactly the ClassicCounter<String> written above.
      @SuppressWarnings("unchecked")
      ClassicCounter<String> c3 = (ClassicCounter<String>) oin.readObject();
      assertEquals(c1, c3);
    }
  }

  /** Checks the minimum count of both fixture counters. */
  public void testMin() {
    assertEquals(1.0, Counters.min(c1));
    assertEquals(5.0, Counters.min(c2));
  }

  /** Checks the key holding the minimum count of both fixture counters. */
  public void testArgmin() {
    assertEquals("p", Counters.argmin(c1));
    assertEquals("p", Counters.argmin(c2));
  }

  /** Checks the L2 norm on a 3-4-5 triangle and on a five element counter. */
  public void testL2Norm() {
    ClassicCounter<String> c = new ClassicCounter<>();
    c.incrementCount("a", 3);
    c.incrementCount("b", 4);
    assertEquals(5.0, Counters.L2Norm(c), TOLERANCE);

    // 9 + 16 + 36 + 16 + 4 = 81
    c.incrementCount("c", 6);
    c.incrementCount("d", 4);
    c.incrementCount("e", 2);
    assertEquals(9.0, Counters.L2Norm(c), TOLERANCE);
  }

  /** Checks in-place log normalisation: expected values are log(4/8), log(2/8), log(1/8). */
  @SuppressWarnings({ "ConstantMathCall" })
  public void testLogNormalize() {
    ClassicCounter<String> c = new ClassicCounter<>();
    c.incrementCount("a", Math.log(4.0));
    c.incrementCount("b", Math.log(2.0));
    c.incrementCount("c", Math.log(1.0));
    c.incrementCount("d", Math.log(1.0));

    Counters.logNormalizeInPlace(c);

    assertEquals(-0.693, c.getCount("a"), TOLERANCE);
    assertEquals(-1.386, c.getCount("b"), TOLERANCE);
    assertEquals(-2.079, c.getCount("c"), TOLERANCE);
    assertEquals(-2.079, c.getCount("d"), TOLERANCE);
    assertEquals(0.0, Counters.logSum(c), TOLERANCE);
  }

  /** Checks L2 normalisation of {4, 2, 1, 2}, whose norm is 5. */
  public void testL2Normalize() {
    ClassicCounter<String> c = new ClassicCounter<>();
    c.incrementCount("a", 4.0);
    c.incrementCount("b", 2.0);
    c.incrementCount("c", 1.0);
    c.incrementCount("d", 2.0);

    Counter<String> d = Counters.L2Normalize(c);

    assertEquals(0.8, d.getCount("a"), TOLERANCE);
    assertEquals(0.4, d.getCount("b"), TOLERANCE);
    assertEquals(0.2, d.getCount("c"), TOLERANCE);
    assertEquals(0.4, d.getCount("d"), TOLERANCE);
  }

  /** Checks that {@code retainAbove} drops keys below the threshold and reports them. */
  public void testRetainAbove() {
    c1 = new ClassicCounter<>();
    c1.incrementCount("a", 1.1);
    c1.incrementCount("b", 1.0);
    c1.incrementCount("c", 0.9);
    c1.incrementCount("d", 0);

    Set<String> removed = Counters.retainAbove(c1, 1.0);

    Set<String> expected = new HashSet<>(Arrays.asList("c", "d"));
    assertEquals(expected, removed);
    assertEquals(1.1, c1.getCount("a"));
    // "b" sits exactly on the threshold and is expected to survive.
    assertEquals(1.0, c1.getCount("b"));
    assertFalse(c1.containsKey("c"));
    assertFalse(c1.containsKey("d"));
  }

  /** Checks ascending and (default) descending key order of {@code toSortedList}. */
  public void testToSortedList() {
    c1 = new ClassicCounter<>();
    c1.incrementCount("a", 0.9);
    c1.incrementCount("b", 1.0);
    c1.incrementCount("c", 1.5);
    c1.incrementCount("d", 0.0);
    c1.incrementCount("e", -2.0);

    List<String> ascending = Arrays.asList("e", "d", "a", "b", "c");
    List<String> descending = new ArrayList<>(ascending);
    Collections.reverse(descending);

    // Comparing whole lists also verifies that no key is missing or duplicated.
    assertEquals(ascending, Counters.toSortedList(c1, true));
    assertEquals(descending, Counters.toSortedList(c1));
  }

  /** Checks that {@code retainTop} keeps only the requested number of largest entries. */
  public void testRetainTop() {
    c1 = new ClassicCounter<>();
    c1.incrementCount("a", 0.9);
    c1.incrementCount("b", 1.0);
    c1.incrementCount("c", 1.5);
    c1.incrementCount("d", 0.0);
    c1.incrementCount("e", -2.0);

    Counters.retainTop(c1, 3);
    assertEquals(3, c1.size());
    assertTrue(c1.containsKey("a"));
    assertFalse(c1.containsKey("d"));

    Counters.retainTop(c1, 1);
    assertEquals(1, c1.size());
    assertTrue(c1.containsKey("c"));
    assertEquals(1.5, c1.getCount("c"));
  }

  /** Checks PMI values for all four cells of a 2x2 joint distribution. */
  public void testPointwiseMutualInformation() {
    Counter<String> x = new ClassicCounter<>();
    x.incrementCount("0", 0.8);
    x.incrementCount("1", 0.2);

    Counter<Integer> y = new ClassicCounter<>();
    y.incrementCount(0, 0.25);
    y.incrementCount(1, 0.75);

    Counter<Pair<String, Integer>> joint = new ClassicCounter<>();
    joint.incrementCount(new Pair<>("0", 0), 0.1);
    joint.incrementCount(new Pair<>("0", 1), 0.7);
    joint.incrementCount(new Pair<>("1", 0), 0.15);
    joint.incrementCount(new Pair<>("1", 1), 0.05);

    // Check that correct PMI values are calculated, using tables from
    // http://en.wikipedia.org/wiki/Pointwise_mutual_information
    final double pmiTolerance = 10e-5;

    double pmi = Counters.pointwiseMutualInformation(x, y, joint, new Pair<>("0", 0));
    assertEquals(-1, pmi, pmiTolerance);

    pmi = Counters.pointwiseMutualInformation(x, y, joint, new Pair<>("0", 1));
    assertEquals(0.222392421, pmi, pmiTolerance);

    pmi = Counters.pointwiseMutualInformation(x, y, joint, new Pair<>("1", 0));
    assertEquals(1.584962501, pmi, pmiTolerance);

    pmi = Counters.pointwiseMutualInformation(x, y, joint, new Pair<>("1", 1));
    assertEquals(-1.584962501, pmi, pmiTolerance);
  }

  /** Checks the {@code toSortedString} family against literals and sibling formatters. */
  public void testToSortedString() {
    Counter<String> c = new ClassicCounter<>();
    c.setCount("b", 0.25);
    c.setCount("a", 0.5);
    c.setCount("c", 1.0);

    // check full argument version
    String result = Counters.toSortedString(c, 5, "%s%.1f", ":", "{%s}");
    assertEquals("{c1.0:a0.5:b0.3}", result);

    // check version with no wrapper
    result = Counters.toSortedString(c, 2, "%2$f %1$s", "\n");
    assertEquals("1.000000 c\n0.500000 a", result);

    // check some equivalences to other Counters methods
    int k = 2;
    result = Counters.toSortedString(c, k, "%s=%s", ", ", "[%s]");
    assertEquals(Counters.toString(c, k), result);
    assertEquals(Counters.toBiggestValuesFirstString(c, k), result);

    result = Counters.toSortedString(c, k, "%2$g\t%1$s", "\n", "%s\n");
    assertEquals(Counters.toVerticalString(c, k), result);

    // test sorting by keys
    result = Counters.toSortedByKeysString(c, "%s=>%.2f", "; ", "<%s>");
    assertEquals("<a=>0.50; b=>0.25; c=>1.00>", result);
  }

  /** Checks the h-index of an empty counter and of counters of growing size. */
  public void testHIndex() {
    // empty counter
    Counter<String> c = new ClassicCounter<>();
    assertEquals(0, Counters.hIndex(c));

    // two items with 2 or more citations
    c.setCount("X", 3);
    c.setCount("Y", 2);
    c.setCount("Z", 1);
    assertEquals(2, Counters.hIndex(c));

    // 14 items with 14 or more citations
    for (int i = 0; i < 14; ++i) {
      c.setCount(String.valueOf(i), 15);
    }
    assertEquals(14, Counters.hIndex(c));

    // 15 items with 15 or more citations
    c.setCount("15", 15);
    assertEquals(15, Counters.hIndex(c));
  }

  /** Checks that adding a collection increments the count once per occurrence. */
  public void testAddInPlaceCollection() {
    List<String> collection = Arrays.asList("p", "p", "s");

    Counters.addInPlace(c1, collection);

    assertEquals(3.0, c1.getCount("p"));
    assertEquals(5.0, c1.getCount("s"));
  }

  /** Checks that {@code removeKeys} leaves only the keys that were not listed. */
  public void testRemoveKeys() {
    Collection<String> toRemove = Arrays.asList("p", "r", "s");

    Counters.removeKeys(c1, toRemove);

    assertEquals(Collections.singleton("q"), c1.keySet());
  }

  /** Checks that {@code retainTopMass(c1, 3)} keeps only "s", whose count 4 covers the mass. */
  public void testRetainTopMass() {
    Counters.retainTopMass(c1, 3);

    assertEquals(Collections.singleton("s"), c1.keySet());
    assertEquals(1, c1.size());
  }

  /** Checks that dividing a two-dimensional counter by its total normalises it to 1. */
  public void testDivideInPlace() {
    TwoDimensionalCounter<String, String> a = new TwoDimensionalCounter<>();
    a.setCount("a", "b", 1);
    a.setCount("a", "c", 1);
    a.setCount("c", "a", 1);
    a.setCount("c", "b", 1);

    Counters.divideInPlace(a, a.totalCount());

    assertEquals(1.0, a.totalCount());
    assertEquals(0.25, a.getCount("a", "b"));
  }

  /**
   * Smoke test: only verifies that the call completes without throwing on counters with
   * partially overlapping key sets.
   */
  public void testPearsonsCorrelationCoefficient() {
    // NOTE(review): no expected value is asserted here; add one once the intended
    // treatment of keys missing from one counter has been confirmed.
    Counters.pearsonsCorrelationCoefficient(c1, c2);
  }

  /** Checks that tied values share the mean of the ranks they occupy. */
  public void testToTiedRankCounter() {
    c1.setCount("t", 1.0);
    c1.setCount("u", 1.0);
    c1.setCount("v", 2.0);
    c1.setCount("z", 4.0);

    Counter<String> rank = Counters.toTiedRankCounter(c1);

    // "s" and "z" tie for ranks 1 and 2; "p", "t" and "u" tie for ranks 6, 7 and 8.
    assertEquals(1.5, rank.getCount("z"));
    assertEquals(7.0, rank.getCount("t"));
  }

  /** Checks that keys which collide after the transformation have their counts summed. */
  public void testTransformWithValuesAdd() {
    c1.setCount("P", 2.0);

    c1 = Counters.transformWithValuesAdd(c1, String::toLowerCase);

    // "P" (2.0) collapses onto "p" (1.0); all other keys are already lower case.
    assertEquals(4, c1.size());
    assertFalse(c1.containsKey("P"));
    assertEquals(3.0, c1.getCount("p"));
    assertEquals(2.0, c1.getCount("q"));
    assertEquals(3.0, c1.getCount("r"));
    assertEquals(4.0, c1.getCount("s"));
  }

  /** Checks exact and tolerance-based {@code Counters.equals}. */
  public void testEquals() {
    c1.clear();
    c2.clear();
    c1.setCount("p", 1.0);
    c1.setCount("q", 2.0);
    c1.setCount("r", 3.0);
    c1.setCount("s", 4.0);
    c2.setCount("p", 1.0);
    c2.setCount("q", 2.0);
    c2.setCount("r", 3.0);
    c2.setCount("s", 4.0);
    assertTrue(Counters.equals(c1, c2));

    // different value for an existing key
    c2.setCount("s", 4.1);
    assertFalse(Counters.equals(c1, c2));

    // missing key
    c2.remove("s");
    assertFalse(Counters.equals(c1, c2));

    // tiny difference: unequal exactly, equal within tolerance
    c2.setCount("s", 4.0 + 1e-10);
    assertFalse(Counters.equals(c1, c2));
    assertTrue(Counters.equals(c1, c2, 1e-5));

    // NOTE(review): "2" is a key that c1 does not have, so this also changes the key set;
    // possibly "r" was intended - confirm before relying on the trailing comment below.
    c2.setCount("2", 3.0 + 8e-5);
    c2.setCount("s", 4.0 + 8e-5);
    assertFalse(Counters.equals(c1, c2, 1e-5)); // fails totalCount() equality check
  }

  /** Checks the Jensen-Shannon divergence for overlapping and for disjoint counters. */
  public void testJensenShannonDivergence() {
    // borrow from ArrayMathTest
    Counter<String> a = new ClassicCounter<>();
    a.setCount("a", 1.0);
    a.setCount("b", 1.0);
    a.setCount("c", 7.0);
    a.setCount("d", 1.0);

    Counter<String> b = new ClassicCounter<>();
    b.setCount("b", 1.0);
    b.setCount("c", 1.0);
    b.setCount("d", 7.0);
    b.setCount("e", 1.0);
    b.setCount("f", 0.0);

    assertEquals(0.46514844544032313, Counters.jensenShannonDivergence(a, b), 1e-5);

    // Counters over disjoint key sets.
    Counter<String> c = new ClassicCounter<>(Collections.singletonList("A"));
    Counter<String> d = new ClassicCounter<>(Arrays.asList("B", "C"));
    assertEquals(1.0, Counters.jensenShannonDivergence(c, d), 1e-5);
  }

  /** Checks that {@code flatten} merges the counters of a map into a single counter. */
  public void testFlatten() {
    Counter<String> a = new ClassicCounter<>();
    a.setCount("a", 1.0);
    a.setCount("b", 1.0);
    a.setCount("c", 7.0);
    a.setCount("d", 1.0);

    Counter<String> b = new ClassicCounter<>();
    b.setCount("b", 1.0);
    b.setCount("c", 1.0);
    b.setCount("d", 7.0);
    b.setCount("e", 1.0);
    b.setCount("f", 1.0);

    Map<String, Counter<String>> h = new HashMap<>();
    h.put("first", a);
    h.put("second", b);

    Counter<String> flat = Counters.flatten(h);

    // keys a..f; "b" occurs in both inputs with count 1.0 each
    assertEquals(6, flat.size());
    assertEquals(2.0, flat.getCount("b"));
  }

  /**
   * Writes a counter covering a wide range of magnitudes and signs to a temporary file,
   * reads it back and checks every value to a relative precision of 1e-5.
   *
   * @throws IOException if the temporary file cannot be created, written or read
   */
  public void testSerializeStringCounter() throws IOException {
    Counter<String> counts = new ClassicCounter<>();
    for (int base = -10; base < 10; ++base) {
      if (base == 0) {
        continue;
      }
      for (int exponent = -100; exponent < 100; ++exponent) {
        double number = Math.pow(Math.PI * base, exponent);
        counts.setCount(Double.toString(number), number);
      }
    }

    File tmp = File.createTempFile("counts", ".tab.gz");
    tmp.deleteOnExit();
    Counters.serializeStringCounter(counts, tmp.getPath());
    Counter<String> reread = Counters.deserializeStringCounter(tmp.getPath());

    // Without this check an empty (or truncated) result would pass the loop below vacuously.
    assertEquals(counts.size(), reread.size());
    for (Map.Entry<String, Double> entry : reread.entrySet()) {
      double old = counts.getCount(entry.getKey());
      assertEquals(old, entry.getValue(), Math.abs(old) / 1e5);
    }
  }
}