package edu.stanford.nlp.stats;
import java.util.*;
import java.io.*;
import junit.framework.TestCase;
import edu.stanford.nlp.util.Factory;
import edu.stanford.nlp.io.IOUtils;
/**
* Base tests that should work on any type of Counter. This class
* is subclassed by e.g., {@link ClassicCounterTest} to provide the
* particular Counter instance being tested.
*/
public abstract class CounterTestBase extends TestCase {
private Counter<String> c;
private final boolean integral;
private static final double TOLERANCE = 0.001;
public CounterTestBase(Counter<String> c) {
this(c, false);
}
public CounterTestBase(Counter<String> c, boolean integral) {
this.c = c;
this.integral = integral;
}
@Override
public void setUp() {
c.clear();
}
public void testClassicCounterHistoricalMain() {
c.setCount("p", 0);
c.setCount("q", 2);
ClassicCounter<String> small_c = new ClassicCounter<String>(c);
Counter<String> c7 = c.getFactory().create();
c7.addAll(c);
assertEquals(c.totalCount(), 2.0);
c.incrementCount("p");
assertEquals(c.totalCount(), 3.0);
c.incrementCount("p", 2.0);
assertEquals(Counters.min(c), 2.0);
assertEquals(Counters.argmin(c), "q");
// Now p is p=3.0, q=2.0
c.setCount("w", -5.0);
c.setCount("x", -4.5);
List<String> biggestKeys = new ArrayList<String>(c.keySet());
assertEquals(biggestKeys.size(), 4);
Collections.sort(biggestKeys, Counters.toComparator(c, false, true));
assertEquals("w", biggestKeys.get(0));
assertEquals("x", biggestKeys.get(1));
assertEquals("p", biggestKeys.get(2));
assertEquals("q", biggestKeys.get(3));
assertEquals(Counters.min(c), -5.0, TOLERANCE);
assertEquals(Counters.argmin(c), "w");
assertEquals(Counters.max(c), 3.0, TOLERANCE);
assertEquals(Counters.argmax(c), "p");
if (integral) {
assertEquals(Counters.mean(c), -1.0);
} else {
assertEquals(Counters.mean(c), -1.125, TOLERANCE);
}
if ( ! integral) {
// only do this for floating point counters. Too much bother to rewrite
c.setCount("x", -2.5);
ClassicCounter<String> c2 = new ClassicCounter<String>(c);
assertEquals(3.0, c2.getCount("p"));
assertEquals(2.0, c2.getCount("q"));
assertEquals(-5.0, c2.getCount("w"));
assertEquals(-2.5, c2.getCount("x"));
Counter<String> c3 = c.getFactory().create();
for (String str: c2.keySet()) {
c3.incrementCount(str);
}
assertEquals(1.0, c3.getCount("p"));
assertEquals(1.0, c3.getCount("q"));
assertEquals(1.0, c3.getCount("w"));
assertEquals(1.0, c3.getCount("x"));
Counters.addInPlace(c2, c3, 10.0);
assertEquals(13.0, c2.getCount("p"));
assertEquals(12.0, c2.getCount("q"));
assertEquals(5.0, c2.getCount("w"));
assertEquals(7.5, c2.getCount("x"));
c3.addAll(c);
assertEquals(4.0, c3.getCount("p"));
assertEquals(3.0, c3.getCount("q"));
assertEquals(-4.0, c3.getCount("w"));
assertEquals(-1.5, c3.getCount("x"));
Counters.subtractInPlace(c3, c);
assertEquals(1.0, c3.getCount("p"));
assertEquals(1.0, c3.getCount("q"));
assertEquals(1.0, c3.getCount("w"));
assertEquals(1.0, c3.getCount("x"));
for (String str : c.keySet()) {
c3.incrementCount(str);
}
assertEquals(2.0, c3.getCount("p"));
assertEquals(2.0, c3.getCount("q"));
assertEquals(2.0, c3.getCount("w"));
assertEquals(2.0, c3.getCount("x"));
Counters.divideInPlace(c2, c3);
assertEquals(6.5, c2.getCount("p"));
assertEquals(6.0, c2.getCount("q"));
assertEquals(2.5, c2.getCount("w"));
assertEquals(3.75, c2.getCount("x"));
Counters.divideInPlace(c2, 0.5);
assertEquals(13.0, c2.getCount("p"));
assertEquals(12.0, c2.getCount("q"));
assertEquals(5.0, c2.getCount("w"));
assertEquals(7.5, c2.getCount("x"));
Counters.multiplyInPlace(c2, 2.0);
assertEquals(26.0, c2.getCount("p"));
assertEquals(24.0, c2.getCount("q"));
assertEquals(10.0, c2.getCount("w"));
assertEquals(15.0, c2.getCount("x"));
Counters.divideInPlace(c2, 2.0);
assertEquals(13.0, c2.getCount("p"));
assertEquals(12.0, c2.getCount("q"));
assertEquals(5.0, c2.getCount("w"));
assertEquals(7.5, c2.getCount("x"));
for (String str : c2.keySet()) {
c2.incrementCount(str);
}
assertEquals(14.0, c2.getCount("p"));
assertEquals(13.0, c2.getCount("q"));
assertEquals(6.0, c2.getCount("w"));
assertEquals(8.5, c2.getCount("x"));
for (String str : c.keySet()) {
c2.incrementCount(str);
}
assertEquals(15.0, c2.getCount("p"));
assertEquals(14.0, c2.getCount("q"));
assertEquals(7.0, c2.getCount("w"));
assertEquals(9.5, c2.getCount("x"));
c2.addAll(small_c);
assertEquals(15.0, c2.getCount("p"));
assertEquals(16.0, c2.getCount("q"));
assertEquals(7.0, c2.getCount("w"));
assertEquals(9.5, c2.getCount("x"));
assertEquals(new HashSet<String>(Arrays.asList("p", "q")), Counters.keysAbove(c2, 14));
assertEquals(new HashSet<String>(Arrays.asList("q")), Counters.keysAt(c2, 16));
assertEquals(new HashSet<String>(Arrays.asList("x", "w")), Counters.keysBelow(c2, 9.5));
Counters.addInPlace(c2,small_c, -6);
assertEquals(15.0, c2.getCount("p"));
assertEquals(4.0, c2.getCount("q"));
assertEquals(7.0, c2.getCount("w"));
assertEquals(9.5, c2.getCount("x"));
Counters.subtractInPlace(c2, small_c);
Counters.subtractInPlace(c2, small_c);
Counters.retainNonZeros(c2);
assertEquals(15.0, c2.getCount("p"));
assertFalse(c2.containsKey("q"));
assertEquals(7.0, c2.getCount("w"));
assertEquals(9.5, c2.getCount("x"));
}
// serialize to Stream
if (c instanceof Serializable) {
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(baos));
out.writeObject(c);
out.close();
// reconstitute
byte[] bytes = baos.toByteArray();
ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new ByteArrayInputStream(bytes)));
c = IOUtils.readObjectFromObjectStream(in);
in.close();
if (!this.integral) {
assertEquals(-2.5, c.totalCount());
assertEquals(-5.0, Counters.min(c));
assertEquals("w", Counters.argmin(c));
}
c.clear();
if (!this.integral) {
assertEquals(0.0, c.totalCount());
}
} catch (IOException ioe) {
fail("IOException: " + ioe);
} catch (ClassNotFoundException cce) {
fail("ClassNotFoundException: " + cce);
}
}
}
public void testFactory() {
Factory<Counter<String>> fcs = c.getFactory();
Counter<String> c2 = fcs.create();
c2.incrementCount("fr");
c2.incrementCount("de");
c2.incrementCount("es", -3);
Counter<String> c3 = fcs.create();
c3.decrementCount("es");
Counter<String> c4 = fcs.create();
c4.incrementCount("fr");
c4.setCount("es", -3);
c4.setCount("de", 1.0);
assertEquals("Testing factory and counter equality", c2, c4);
assertEquals("Testing factory", c2.totalCount(), -1.0);
c3.addAll(c2);
assertEquals(c3.keySet().size(), 3);
assertEquals(c3.size(), 3);
assertEquals("Testing addAll", -2.0, c3.totalCount());
}
public void testReturnValue() {
c.setDefaultReturnValue(-1);
assertEquals(c.defaultReturnValue(), -1.0);
assertEquals(c.getCount("-!-"), -1.0);
c.setDefaultReturnValue(0.0);
assertEquals(c.getCount("-!-"), 0.0);
}
public void testSetCount() {
c.clear();
c.setCount("p", 0);
c.setCount("q", 2);
assertEquals("Failed setCount", 2.0, c.totalCount());
assertEquals("Failed setCount", 2.0, c.getCount("q"));
}
public void testIncrement() {
c.clear();
assertEquals(0., c.getCount("r"));
assertEquals(1., c.incrementCount("r"));
assertEquals(1., c.getCount("r"));
c.setCount("p", 0);
c.setCount("q", 2);
assertEquals(true, c.containsKey("q"));
assertEquals(false, c.containsKey("!!!"));
assertEquals(0., c.getCount("p"));
assertEquals(1., c.incrementCount("p"));
assertEquals(1., c.getCount("p"));
assertEquals(4., c.totalCount());
c.decrementCount("s", 5.0);
assertEquals(-5.0, c.getCount("s"));
c.remove("s");
assertEquals(4.0, c.totalCount());
}
public void testIncrement2() {
c.clear();
c.setCount("p", .5);
c.setCount("q", 2);
if (integral) {
assertEquals(3., c.incrementCount("p", 3.5));
assertEquals(3., c.getCount("p"));
assertEquals(5., c.totalCount());
} else {
assertEquals(4., c.incrementCount("p", 3.5));
assertEquals(4., c.getCount("p"));
assertEquals(6., c.totalCount());
}
}
public void testLogIncrement() {
c.clear();
c.setCount("p", Math.log(.5));
// System.out.println(c.getCount("p"));
c.setCount("q", Math.log(.2));
// System.out.println(c.getCount("q"));
if (integral) {
// 0.5 gives 0 and 0.3 gives -1, so -1
double ans = c.logIncrementCount("p", Math.log(.3));
// System.out.println(ans);
assertEquals(0., ans, .0001);
assertEquals(-1., c.totalCount(), .0001);
} else {
assertEquals(Math.log(.5+.3), c.logIncrementCount("p", Math.log(.3)), .0001);
assertEquals(Math.log(.5+.3)+Math.log(.2), c.totalCount(), .0001);
}
}
public void testEntrySet() {
c.clear();
c.setCount("r", 3.0);
c.setCount("p", 1.0);
c.setCount("q", 2.0);
c.setCount("s", 4.0);
assertEquals(10.0, c.totalCount());
assertEquals(1.0, c.getCount("p"));
for (Map.Entry<String,Double> entry : c.entrySet()) {
if (entry.getKey().equals("p")) {
assertEquals(1.0, entry.setValue(3.0));
assertEquals(3.0, entry.getValue());
}
}
assertEquals(3.0, c.getCount("p"));
assertEquals(12.0, c.totalCount());
Collection<Double> vals = c.values();
double tot = 0.0;
for (double d : vals) {
tot += d;
}
assertEquals("Testing values()", 12.0, tot);
}
public void testComparators() {
c.clear();
c.setCount("b", 3.0);
c.setCount("p", -5.0);
c.setCount("a", 2.0);
c.setCount("s", 4.0);
List<String> list = new ArrayList<String>(c.keySet());
Comparator<String> cmp = Counters.toComparator(c);
Collections.sort(list, cmp);
assertEquals(4, list.size());
assertEquals("p", list.get(0));
assertEquals("a", list.get(1));
assertEquals("b", list.get(2));
assertEquals("s", list.get(3));
Comparator<String> cmp2 = Counters.toComparatorDescending(c);
Collections.sort(list, cmp2);
assertEquals(4, list.size());
assertEquals("p", list.get(3));
assertEquals("a", list.get(2));
assertEquals("b", list.get(1));
assertEquals("s", list.get(0));
Comparator<String> cmp3 = Counters.toComparator(c, true, true);
Collections.sort(list, cmp3);
assertEquals(4, list.size());
assertEquals("p", list.get(3));
assertEquals("a", list.get(0));
assertEquals("b", list.get(1));
assertEquals("s", list.get(2));
Comparator<String> cmp4 = Counters.toComparator(c, false, true);
Collections.sort(list, cmp4);
assertEquals(4, list.size());
assertEquals("p", list.get(0));
assertEquals("a", list.get(3));
assertEquals("b", list.get(2));
assertEquals("s", list.get(1));
Comparator<String> cmp5 = Counters.toComparator(c, false, false);
Collections.sort(list, cmp5);
assertEquals(4, list.size());
assertEquals("p", list.get(3));
assertEquals("a", list.get(2));
assertEquals("b", list.get(1));
assertEquals("s", list.get(0));
}
public void testClear() {
c.incrementCount("xy", 30);
c.clear();
assertEquals(0.0, c.totalCount());
}
}