package shared.tester; import java.util.HashMap; import java.util.Map; import shared.Instance; /** * A test metric to generate a confusion matrix. This metric expects the true labels * to be supplied at construction time, both to make sure the results are binned correctly * and to ensure clean output. * * @author Jesse Rosalia <https://github.com/theJenix> * @date 2013-03-05 */ public class ConfusionMatrixTestMetric implements TestMetric { /** * A matrix entry. This class holds an expected and actual instance * as a pair, and defines equals and hashCode for that pair. * * @author thejenix * */ private class MatrixEntry { private Instance expected; private Instance actual; public MatrixEntry(Instance expected, Instance actual) { this.expected = expected; this.actual = actual; } @Override public boolean equals(Object arg0) { if (!(arg0 instanceof MatrixEntry)) { return super.equals(arg0); } MatrixEntry me = (MatrixEntry) arg0; if (me.expected.size() != expected.size() || me.actual.size() != actual.size()) { return false; } //use the comparison class to test that these are equal Comparison c = new Comparison(expected, me.expected); Comparison d = new Comparison(actual, me.actual); return c.isAllCorrect() && d.isAllCorrect(); } @Override public int hashCode() { int hashCode = 0; for (int ii = 0; ii < expected.size(); ii++) { //scale the expected value, to provide separation between corresponding pairs // (e.g. a, b should be different from b, a) hashCode += 0x10000 * expected.getContinuous(ii); hashCode += actual.getContinuous(ii); } return hashCode; } } private Instance[] labels; private String[] labelStrs; private Instance nullLabel = new Instance(-1); private Map<MatrixEntry, Integer> matrix = new HashMap<MatrixEntry, Integer>(); /** * Construct the test metric with double valued labels. * * NOTE: these display with several significant figures...we may want to change this. * @param labels */ public ConfusionMatrixTestMetric(double[] labels) { this.labels = new Instance[labels.length]; this.labelStrs = new String[labels.length]; for (int i = 0; i < labels.length; i++) { this.labels [i] = new Instance(labels[i]); this.labelStrs[i] = Double.toString(labels[i]); } } /** * Construct the test metric with discrete (integer) labels. * * @param labels */ public ConfusionMatrixTestMetric(int[] labels) { this.labels = new Instance[labels.length]; this.labelStrs = new String[labels.length]; for (int i = 0; i < labels.length; i++) { this.labels [i] = new Instance(labels[i]); this.labelStrs[i] = Integer.toString(labels[i]); } } /** * Construct the test metric with boolean labels. * * @param labels */ public ConfusionMatrixTestMetric(boolean[] labels) { this.labels = new Instance[labels.length]; this.labelStrs = new String[labels.length]; for (int i = 0; i < labels.length; i++) { this.labels [i] = new Instance(labels[i]); //use "t" and "f" as the output string, for brevity this.labelStrs[i] = labels[i] ? "t" : "f"; } } @Override public void addResult(Instance expected, Instance actual) { Comparison c = new Comparison(expected, actual); for (int ii = 0; ii < expected.size(); ii++) { //find the actual value in the list of classes //...this makes sure we work with homogeneous label values, so our // matrix is readable. Instance found = findLabel(this.labels, actual); MatrixEntry e = new MatrixEntry(expected, found); if (matrix.containsKey(e)) { matrix.put(e, matrix.get(e) + 1); } else { matrix.put(e, 1); } } } /** * Find a label in the array of expected labels, using the Comparison class to validate correctness. * This is important for building the matrix, as it smooths out the noise (however small) that may be present * in the output of the classifier. * * @param labels * @param toFind * @return The corresponding label instance found in the array, or an object to represent the null label (i.e. not found) */ private Instance findLabel(Instance[] labels, Instance toFind) { Instance found = this.nullLabel; for (Instance label : labels) { Comparison c = new Comparison(label, toFind); if (c.isAllCorrect()) { found = label; break; } } return found; } @Override public void printResults() { System.out.println("Confusion Matrix:"); System.out.println(); //TODO: substitute letters instead of the acutal values, for the axes (like weka) //print the top labels for (String l : labelStrs) { System.out.print("\t"); System.out.print(l); } System.out.print("\t"); System.out.print("?"); for (int ii = 0; ii < labels.length; ii++) { Instance lr = labels[ii]; System.out.println(); System.out.print(labelStrs[ii]); for (Instance lc : labels) { Integer val = matrix.get(new MatrixEntry(lr, lc)); if (val == null) { val = 0; } System.out.print("\t"); System.out.print(val); } Integer val = matrix.get(new MatrixEntry(lr, this.nullLabel)); if (val == null) { val = 0; } System.out.print("\t"); System.out.print(val); } } }