package edu.stanford.nlp.ie;
import org.junit.Test;
import java.util.Arrays;
import java.util.HashSet;
import static org.junit.Assert.*;
/**
 * Tests for {@link KBPRelationExtractor.Accuracy}, the scoring helper nested inside
 * {@link KBPRelationExtractor}.
 *
 * <p>All three tests feed the same six core observations into a fresh {@code Accuracy} and then
 * check the same per-label and aggregate precision/recall figures. They differ only in how a
 * "nothing here" side of an observation is spelled (an empty set versus a set holding
 * {@code "no_relation"}) and in whether extra observations that are {@code "no_relation"} on both
 * sides are appended. The expected numbers are identical in every case.
 */
public class KBPRelationExtractorTest {

  /** Absolute tolerance applied to every floating-point comparison below. */
  private static final double EPSILON = 1e-10;

  /** Label used by the tests to spell out "no relation" explicitly. */
  private static final String NO_RELATION = "no_relation";

  /**
   * Builds a fresh mutable set holding the given labels; calling it with no arguments yields an
   * empty set.
   *
   * @param names the labels to put in the set
   * @return a new {@link HashSet} containing exactly {@code names}
   */
  private static HashSet<String> labels(String... names) {
    return new HashSet<>(Arrays.asList(names));
  }

  /**
   * Records the six observations shared by every test in this class.
   *
   * <p>NOTE(review): the asserted scores are consistent with the first argument of
   * {@code predict} being the guessed labels and the second the gold labels; confirm against
   * {@code Accuracy.predict} itself.
   *
   * @param scorer the accuracy tracker to feed
   * @param absent the labels standing in for "nothing" on one side of an observation; pass no
   *     arguments to use an empty set
   */
  private static void recordSharedObservations(KBPRelationExtractor.Accuracy scorer,
                                               String... absent) {
    // A new set is allocated for every argument, exactly as the inlined calls used to do.
    scorer.predict(labels("a"), labels("a"));
    scorer.predict(labels("a"), labels(absent));
    scorer.predict(labels(absent), labels("b"));
    scorer.predict(labels("b"), labels(absent));
    scorer.predict(labels("b"), labels("b"));
    scorer.predict(labels("b"), labels("b"));
  }

  /**
   * Asserts the scores that the six shared observations are expected to produce.
   *
   * @param scorer the accuracy tracker to inspect
   */
  private static void assertSharedScores(KBPRelationExtractor.Accuracy scorer) {
    // Label "a": one hit, two occurrences in the first argument, one in the second.
    final double precisionA = 0.5;
    final double recallA = 1.0;
    assertEquals(precisionA, scorer.precision("a"), EPSILON);
    assertEquals(recallA, scorer.recall("a"), EPSILON);
    assertEquals(2.0 * recallA * precisionA / (recallA + precisionA), scorer.f1("a"), EPSILON);

    // Label "b": two hits, three occurrences on each side.
    final double twoThirds = 2.0 / 3.0;
    assertEquals(twoThirds, scorer.precision("b"), EPSILON);
    assertEquals(twoThirds, scorer.recall("b"), EPSILON);

    // Aggregates over both labels.
    assertEquals(3.0 / 5.0, scorer.precisionMicro(), EPSILON);
    assertEquals(7.0 / 12.0, scorer.precisionMacro(), EPSILON);
    assertEquals(3.0 / 4.0, scorer.recallMicro(), EPSILON);
    assertEquals(5.0 / 6.0, scorer.recallMacro(), EPSILON);
  }

  /** Baseline: the "nothing" side of an observation is an empty set. */
  @Test
  public void testAccuracySimple() {
    KBPRelationExtractor.Accuracy scorer = new KBPRelationExtractor.Accuracy();
    recordSharedObservations(scorer);
    assertSharedScores(scorer);
  }

  /**
   * Same observations as the baseline, but with {@code "no_relation"} in place of the empty set;
   * the expected scores are unchanged.
   */
  @Test
  public void testAccuracyNoRelation() {
    KBPRelationExtractor.Accuracy scorer = new KBPRelationExtractor.Accuracy();
    recordSharedObservations(scorer, NO_RELATION);
    assertSharedScores(scorer);
  }

  /**
   * Adds three observations that are {@code "no_relation"} on both sides after the shared ones;
   * the expected scores are still unchanged.
   */
  @Test
  public void testAccuracyTrueNegatives() {
    KBPRelationExtractor.Accuracy scorer = new KBPRelationExtractor.Accuracy();
    recordSharedObservations(scorer, NO_RELATION);
    for (int i = 0; i < 3; i++) {
      scorer.predict(labels(NO_RELATION), labels(NO_RELATION));
    }
    assertSharedScores(scorer);
  }
}