package edu.stanford.nlp.util;
import java.util.Locale;
import junit.framework.TestCase;
/**
* Tests that the output of the ConfusionMatrix is in the expected format.
*
* @author Eric Yeh yeh1@cs.stanford.edu
*/
public class ConfusionMatrixTest extends TestCase {
boolean echo;
public ConfusionMatrixTest() {
this(false);
}
public ConfusionMatrixTest(boolean echo) {
this.echo = echo;
}
public void testBasic() {
String expected = " Guess/Gold C1 C2 C3 Marg. (Guess)\n" +
" C1 2 0 0 2\n" +
" C2 1 0 0 1\n" +
" C3 0 0 1 1\n" +
" Marg. (Gold) 3 0 1\n\n" +
" C1 = a prec=1, recall=0.66667, spec=1, f1=0.8\n" +
" C2 = b prec=0, recall=n/a, spec=0.75, f1=n/a\n" +
" C3 = c prec=1, recall=1, spec=1, f1=1\n";
ConfusionMatrix<String> conf = new ConfusionMatrix<String>(Locale.US);
conf.add("a","a");
conf.add("a","a");
conf.add("b","a");
conf.add("c","c");
String result = conf.printTable();
if (echo) {
System.err.println(result);
} else {
assertEquals(expected, result);
}
}
public void testRealLabels() {
String expected = " Guess/Gold a b c Marg. (Guess)\n" +
" a 2 0 0 2\n" +
" b 1 0 0 1\n" +
" c 0 0 1 1\n" +
" Marg. (Gold) 3 0 1\n\n" +
" a prec=1, recall=0.66667, spec=1, f1=0.8\n" +
" b prec=0, recall=n/a, spec=0.75, f1=n/a\n" +
" c prec=1, recall=1, spec=1, f1=1\n";
ConfusionMatrix<String> conf = new ConfusionMatrix<String>(Locale.US);
conf.setUseRealLabels(true);
conf.add("a","a");
conf.add("a","a");
conf.add("b","a");
conf.add("c","c");
String result = conf.printTable();
if (echo) {
System.err.println(result);
} else {
assertEquals(expected, result);
}
}
public void testBulkAdd() {
String expected = " Guess/Gold C1 C2 Marg. (Guess)\n" +
" C1 10 5 15\n" +
" C2 2 3 5\n" +
" Marg. (Gold) 12 8\n\n" +
" C1 = 1 prec=0.66667, recall=0.83333, spec=0.375, f1=0.74074\n" +
" C2 = 2 prec=0.6, recall=0.375, spec=0.83333, f1=0.46154\n";
ConfusionMatrix<Integer> conf = new ConfusionMatrix<Integer>(Locale.US);
conf.add(1,1, 10);
conf.add(1,2, 5);
conf.add(2,1,2);
conf.add(2,2,3);
String result = conf.printTable();
if (echo) {
System.err.println(result);
} else {
assertEquals(expected, result);
}
}
private static class BackwardsInteger implements Comparable<BackwardsInteger> {
private final int value;
public BackwardsInteger(int value) {
this.value = value;
}
public int compareTo(BackwardsInteger other) {
return other.value - this.value; // backwards
}
@Override
public int hashCode() {
return value;
}
public boolean equals(Object o) {
if (o == null || (!(o instanceof BackwardsInteger))) {
return false;
}
return (((BackwardsInteger) o).value == value);
}
@Override
public String toString() {
return Integer.toString(value);
}
}
public void testValueSort() {
String expected = " Guess/Gold 2 1 Marg. (Guess)\n" +
" 2 3 2 5\n" +
" 1 5 10 15\n" +
" Marg. (Gold) 8 12\n\n" +
" 2 prec=0.6, recall=0.375, spec=0.83333, f1=0.46154\n" +
" 1 prec=0.66667, recall=0.83333, spec=0.375, f1=0.74074\n";
BackwardsInteger one = new BackwardsInteger(1);
BackwardsInteger two = new BackwardsInteger(2);
ConfusionMatrix<BackwardsInteger> conf = new ConfusionMatrix<BackwardsInteger>(Locale.US);
conf.setUseRealLabels(true);
conf.add(one, one, 10);
conf.add(one, two, 5);
conf.add(two, one, 2);
conf.add(two, two, 3);
String result = conf.printTable();
if (echo) {
System.err.println(result);
} else {
assertEquals(expected, result);
}
}
public static void main(String[] args) {
ConfusionMatrixTest tester = new ConfusionMatrixTest(true);
System.out.println("Test 1");
tester.testBasic();
System.out.println("\nTest 2");
tester.testRealLabels();
System.out.println("\nTest 3");
tester.testBulkAdd();
System.out.println("\nTest 4");
tester.testValueSort();
}
}