package hex;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.FileUtils;
import water.util.VecUtils;
import static water.util.FrameUtils.parseFrame;
import java.io.IOException;
import java.util.Arrays;
/**
 * Tests for {@link ConfusionMatrix#buildCM}: each test builds a confusion matrix from two
 * single-column inputs (the first is treated by these tests as the "actual" labels, the
 * second as the "predicted" labels) and compares the resulting domain and counts with
 * hand-written expectations.
 */
public class ConfusionMatrixTest extends TestUtil {
  @BeforeClass
  public static void stall() { stall_till_cloudsize(5); }

  /** Flip to {@code true} to dump domains and the computed matrix to stderr while debugging. */
  final boolean debug = false;

  /** A vector compared with itself must yield a purely diagonal matrix. */
  @Test
  public void testIdenticalVectors() {
    // Enter the scope before the try so that exit() is only ever paired with a successful enter().
    Scope.enter();
    try {
      simpleCMTest(
          "smalldata/junit/cm/v1.csv",
          "smalldata/junit/cm/v1.csv",
          ar("A", "B", "C"),
          ar("A", "B", "C"),
          ar("A", "B", "C"),
          ard(ard(2, 0, 0),
              ard(0, 2, 0),
              ard(0, 0, 1)),
          debug);
    } finally {
      Scope.exit();
    }
  }

  /** Two different vectors sharing the same domain {A, B, C}. */
  @Test
  public void testVectorAlignment() {
    simpleCMTest(
        "smalldata/junit/cm/v1.csv",
        "smalldata/junit/cm/v2.csv",
        ar("A", "B", "C"),
        ar("A", "B", "C"),
        ar("A", "B", "C"),
        ard(ard(1, 1, 0),
            ard(0, 1, 1),
            ard(0, 0, 1)),
        debug);
  }

  /**
   * Negative test: an {@link IllegalArgumentException} is expected when two vectors
   * of different lengths are provided.
   */
  @Test(expected = IllegalArgumentException.class)
  public void testDifferentLengthVectors() {
    simpleCMTest(
        "smalldata/junit/cm/v1.csv",
        "smalldata/junit/cm/v3.csv",
        ar("A", "B", "C"),
        ar("A", "B", "C"),
        ar("A", "B", "C"),
        ard(ard(1, 1, 0),
            ard(0, 1, 1),
            ard(0, 0, 1)),
        debug);
  }

  /** The predicted vector uses only the sub-domain {B, C}; the matrix must still span {A, B, C}. */
  @Test
  public void testDifferentDomains() {
    simpleCMTest(
        "smalldata/junit/cm/v1.csv",
        "smalldata/junit/cm/v4.csv",
        ar("A", "B", "C"),
        ar("B", "C"),
        ar("A", "B", "C"),
        ard(ard(0, 2, 0),
            ard(0, 0, 2),
            ard(0, 0, 1)),
        debug);

    simpleCMTest(
        "smalldata/junit/cm/v2.csv",
        "smalldata/junit/cm/v4.csv",
        ar("A", "B", "C"),
        ar("B", "C"),
        ar("A", "B", "C"),
        ard(ard(0, 1, 0),
            ard(0, 1, 1),
            ard(0, 0, 2)),
        debug);
  }

  /** Same scenarios as the categorical tests above, but the input files hold the numbers 0, 1, 2. */
  @Test
  public void testSimpleNumericVectors() {
    simpleCMTest(
        "smalldata/junit/cm/v1n.csv",
        "smalldata/junit/cm/v1n.csv",
        ar("0", "1", "2"),
        ar("0", "1", "2"),
        ar("0", "1", "2"),
        ard(ard(2, 0, 0),
            ard(0, 2, 0),
            ard(0, 0, 1)),
        debug);

    simpleCMTest(
        "smalldata/junit/cm/v1n.csv",
        "smalldata/junit/cm/v2n.csv",
        ar("0", "1", "2"),
        ar("0", "1", "2"),
        ar("0", "1", "2"),
        ard(ard(1, 1, 0),
            ard(0, 1, 1),
            ard(0, 0, 1)),
        debug);
  }

  /** Numeric inputs where the predicted vector covers only the values {1, 2}. */
  @Test
  public void testDifferentDomainsNumericVectors() {
    simpleCMTest(
        "smalldata/junit/cm/v1n.csv",
        "smalldata/junit/cm/v4n.csv",
        ar("0", "1", "2"),
        ar("1", "2"),
        ar("0", "1", "2"),
        ard(ard(0, 2, 0),
            ard(0, 0, 2),
            ard(0, 0, 1)),
        debug);

    simpleCMTest(
        "smalldata/junit/cm/v2n.csv",
        "smalldata/junit/cm/v4n.csv",
        ar("0", "1", "2"),
        ar("1", "2"),
        ar("0", "1", "2"),
        ard(ard(0, 1, 0),
            ard(0, 1, 1),
            ard(0, 0, 2)),
        debug);
  }

  /**
   * Test for PUB-216:
   * the vector's domain is set to (0~A, 1~B, 2~C), but the values actually stored in the
   * vector reference only a subset of that domain (1~B, 2~C). The TransfVec used the minimum
   * found in the vector (i.e., value 1) to compute the transformation, which was wrong because
   * the minimum should have been 0.
   */
  @Test
  public void testBadModelPrect() {
    simpleCMTest(
        ArrayUtils.frame("v1", vec(ar("A", "B", "C"), ari(0, 0, 1, 1, 2))),
        ArrayUtils.frame("v1", vec(ar("A", "B", "C"), ari(1, 1, 2, 2, 2))),
        ar("A", "B", "C"),
        ar("A", "B", "C"),
        ar("A", "B", "C"),
        ard(ard(0, 2, 0),
            ard(0, 0, 2),
            ard(0, 0, 1)),
        debug);
  }

  /** Variant of {@link #testBadModelPrect()} where the predicted vector's domain {0, 1} is shorter than the actual one. */
  @Test
  public void testBadModelPrect2() {
    simpleCMTest(
        ArrayUtils.frame("v1", vec(ar("-1", "0", "1"), ari(0, 0, 1, 1, 2))),
        ArrayUtils.frame("v1", vec(ar("0", "1"), ari(0, 0, 1, 1, 1))),
        ar("-1", "0", "1"),
        ar("0", "1"),
        ar("-1", "0", "1"),
        ard(ard(0, 2, 0),
            ard(0, 0, 2),
            ard(0, 0, 1)),
        debug);
  }

  /**
   * Parses the two files into frames, makes the second compatible with the first when
   * {@code isCompatible} reports they are not, and delegates to the frame-based overload.
   *
   * @param f1 path of the file holding the first ("actual") column
   * @param f2 path of the file holding the second ("predicted") column
   * @param expectedActualDomain domain of the first column (only printed in debug mode)
   * @param expectedPredictDomain domain of the second column (only printed in debug mode)
   * @param expectedDomain domain the resulting confusion matrix must have
   * @param expectedCM expected counts
   * @param debug print diagnostic output to stderr
   * @throws RuntimeException wrapping the {@link IOException} if a file cannot be read;
   *         previously this was only printed, letting the calling test pass without asserting anything
   */
  private void simpleCMTest(String f1, String f2, String[] expectedActualDomain, String[] expectedPredictDomain, String[] expectedDomain, double[][] expectedCM, boolean debug) {
    Frame v1 = null;
    Frame v2 = null;
    // Once true, the frame-based overload owns v1/v2 and deletes them in its own finally.
    boolean handedOff = false;
    try {
      v1 = parseFrame(Key.make("v1.hex"), FileUtils.getFile(f1));
      v2 = parseFrame(Key.make("v2.hex"), FileUtils.getFile(f2));
      if (!v1.isCompatible(v2)) {
        Frame old = v2;
        v2 = new Frame(v1.makeCompatible(old));
        old.delete();
      }
      handedOff = true;
      simpleCMTest(v1, v2, expectedActualDomain, expectedPredictDomain, expectedDomain, expectedCM, debug);
    } catch (IOException e) {
      throw new RuntimeException("Failed to load test frames from " + f1 + " and " + f2, e);
    } finally {
      // Anything thrown before the hand-off (a parse failure, or an exception raised while
      // making the frames compatible) would otherwise leave the parsed frames behind.
      if (!handedOff) {
        if (v1 != null) v1.delete();
        if (v2 != null) v2.delete();
      }
    }
  }

  /**
   * Builds a confusion matrix from the first column of each frame (converted to categorical)
   * and asserts that it matches the expectation.
   * Deletes v1, v2 after potential modifying operations during processing:
   * categoricals and/or train/test adaptation.
   *
   * @param v1 frame whose first column is used as the first argument of {@code buildCM}
   * @param v2 frame whose first column is used as the second argument of {@code buildCM}
   * @param actualDomain domain of the first column (only printed in debug mode)
   * @param predictedDomain domain of the second column (only printed in debug mode)
   * @param expectedDomain domain the resulting confusion matrix must have
   * @param expectedCM expected counts
   * @param debug print diagnostic output to stderr
   */
  private void simpleCMTest(Frame v1, Frame v2, String[] actualDomain, String[] predictedDomain, String[] expectedDomain, double[][] expectedCM, boolean debug) {
    Scope.enter();
    try {
      ConfusionMatrix cm = ConfusionMatrix.buildCM(VecUtils.toCategoricalVec(v1.vecs()[0]), VecUtils.toCategoricalVec(v2.vecs()[0]));

      // -- DEBUG --
      if (debug) {
        System.err.println("actual            : " + Arrays.toString(actualDomain));
        System.err.println("predicted         : " + Arrays.toString(predictedDomain));
        System.err.println("CM domain         : " + Arrays.toString(cm._domain));
        System.err.println("expected CM domain: " + Arrays.toString(expectedDomain) + "\n");
        for (double[] row : cm._cm) {
          System.err.println(Arrays.toString(row));
        }
        System.err.println("");
        System.err.println(cm.toASCII());
      }
      // -- -- --

      assertCMEqual(expectedDomain, expectedCM, cm);
    } finally {
      if (v1 != null) v1.delete();
      if (v2 != null) v2.delete();
      Scope.exit();
    }
  }

  /**
   * Asserts that {@code actualCM} has the expected domain, the expected number of rows, and
   * row-by-row the expected counts (compared with a tolerance of 1e-10).
   */
  private void assertCMEqual(String[] expectedDomain, double[][] expectedCM, ConfusionMatrix actualCM) {
    Assert.assertArrayEquals("Expected domain differs", expectedDomain, actualCM._domain);
    double[][] acm = actualCM._cm;
    Assert.assertEquals("CM dimension differs", expectedCM.length, acm.length);
    for (int i = 0; i < acm.length; i++) {
      Assert.assertArrayEquals("CM row " +i+" differs!", expectedCM[i], acm[i],1e-10);
    }
  }
}