package com.github.lwhite1.tablesaw.api.ml.classification; import com.github.lwhite1.tablesaw.api.CategoryColumn; import com.github.lwhite1.tablesaw.api.IntColumn; import com.google.common.collect.Table; import com.google.common.collect.TreeBasedTable; import java.util.ArrayList; import java.util.List; import java.util.Set; import java.util.SortedMap; import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; /** * A confusion matrix is used to measure the accuracy of a classifier by counting the number of correct and * incorrect values produced when testing the classifier such that the counts are made for every combination of * correct and incorrect classification */ public class StandardConfusionMatrix implements ConfusionMatrix { private final Table<Integer, Integer, Integer> table = TreeBasedTable.create(); private SortedMap<Integer, Object> labels = new TreeMap<>(); public StandardConfusionMatrix(SortedSet<Object> labels) { int i = 0; for (Object object : labels) { this.labels.put(i, object); i++; } } public void increment(Integer predicted, Integer actual) { Integer v = table.get(predicted, actual); if (v == null) { table.put(predicted, actual, 1); } else { table.put(predicted, actual, v + 1); } } @Override public String toString() { return toTable().print(); } public com.github.lwhite1.tablesaw.api.Table toTable() { com.github.lwhite1.tablesaw.api.Table t = com.github.lwhite1.tablesaw.api.Table.create("Confusion Matrix"); t.addColumn(CategoryColumn.create("")); // make a set of all the values needed, from the prediction set or the actual set TreeSet<Integer> allValues = new TreeSet<>(); allValues.addAll(table.columnKeySet()); allValues.addAll(table.rowKeySet()); for (Integer comparable : allValues) { t.addColumn(IntColumn.create(String.valueOf(labels.get(comparable)))); t.column(0).addCell("Predicted " + labels.get(comparable)); } // put them in a list so they can be accessed by index number List<Comparable> valuesList = new ArrayList<>(allValues); int n = 0; for (int r = 0; r < valuesList.size(); r++) { for (int c = 0; c < valuesList.size(); c++) { Integer value = table.get(valuesList.get(r), valuesList.get(c)); if (value == null) { t.intColumn(c + 1).add(0); } else { t.intColumn(c + 1).add(value); n = n + value; } } } t.column(0).setName("n = " + n); for (int c = 1; c <= valuesList.size(); c++) { t.column(c).setName("Actual " + labels.get(c - 1)); } return t; } public double accuracy() { Set<Table.Cell<Integer, Integer, Integer>> cellSet = table.cellSet(); int hits = 0; int misses = 0; for (Table.Cell cell : cellSet) { if (cell.getRowKey().equals(cell.getColumnKey())) { hits = hits + (int) cell.getValue(); } else { misses = misses + (int) cell.getValue(); } } return hits / ((hits + misses) * 1.0); } }