/**
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* Licensed to the Apache Software Foundation (ASF) under one or more
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
/**
* The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
*
* The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
*
* See http://en.wikipedia.org/wiki/Confusion_matrix for background
*/
public class ConfusionMatrix {
private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
private final int[][] confusionMatrix;
private String defaultLabel = "unknown";
public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
this.defaultLabel = defaultLabel;
int i = 0;
for (String label : labels) {
labelMap.put(label, i++);
}
labelMap.put(defaultLabel, i);
}
public ConfusionMatrix(Matrix m) {
confusionMatrix = new int[m.numRows()][m.numRows()];
setMatrix(m);
}
public int[][] getConfusionMatrix() {
return confusionMatrix;
}
public Collection<String> getLabels() {
return Collections.unmodifiableCollection(labelMap.keySet());
}
public double getAccuracy(String label) {
int labelId = labelMap.get(label);
int labelTotal = 0;
int correct = 0;
for (int i = 0; i < labelMap.size(); i++) {
labelTotal += confusionMatrix[labelId][i];
if (i == labelId) {
correct = confusionMatrix[labelId][i];
}
}
return 100.0 * correct / labelTotal;
}
public int getCorrect(String label) {
int labelId = labelMap.get(label);
return confusionMatrix[labelId][labelId];
}
public int getTotal(String label) {
int labelId = labelMap.get(label);
int labelTotal = 0;
for (int i = 0; i < labelMap.size(); i++) {
labelTotal += confusionMatrix[labelId][i];
}
return labelTotal;
}
public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
incrementCount(correctLabel, classifiedResult.getLabel());
}
public void addInstance(String correctLabel, String classifiedLabel) {
incrementCount(correctLabel, classifiedLabel);
}
public int getCount(String correctLabel, String classifiedLabel) {
Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label not found: " + correctLabel);
Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
int correctId = labelMap.get(correctLabel);
int classifiedId = labelMap.get(classifiedLabel);
return confusionMatrix[correctId][classifiedId];
}
public void putCount(String correctLabel, String classifiedLabel, int count) {
Preconditions.checkArgument(labelMap.containsKey(correctLabel), "Label not found: " + correctLabel);
Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
int correctId = labelMap.get(correctLabel);
int classifiedId = labelMap.get(classifiedLabel);
confusionMatrix[correctId][classifiedId] = count;
}
public String getDefaultLabel() {
return defaultLabel;
}
public void incrementCount(String correctLabel, String classifiedLabel, int count) {
putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
}
public void incrementCount(String correctLabel, String classifiedLabel) {
incrementCount(correctLabel, classifiedLabel, 1);
}
public ConfusionMatrix merge(ConfusionMatrix b) {
Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
for (String correctLabel : this.labelMap.keySet()) {
for (String classifiedLabel : this.labelMap.keySet()) {
incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
}
}
return this;
}
public Matrix getMatrix() {
int length = confusionMatrix.length;
Matrix m = new DenseMatrix(length, length);
for (int r = 0; r < length; r++) {
for (int c = 0; c < length; c++) {
m.set(r, c, confusionMatrix[r][c]);
}
}
Map<String,Integer> labels = Maps.newHashMap();
for(Map.Entry<String, Integer> entry : labelMap.entrySet()) {
labels.put(entry.getKey(), entry.getValue());
}
m.setRowLabelBindings(labels);
m.setColumnLabelBindings(labels);
return m;
}
public void setMatrix(Matrix m) {
int length = confusionMatrix.length;
if (m.numRows() != m.numCols()) {
throw new IllegalArgumentException(
"ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square");
}
for (int r = 0; r < length; r++) {
for (int c = 0; c < length; c++) {
confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
}
}
Map<String,Integer> labels = m.getRowLabelBindings();
if (labels == null) {
labels = m.getColumnLabelBindings();
}
if (labels != null) {
String[] sorted = sortLabels(labels);
verifyLabels(length, sorted);
labelMap.clear();
for(int i = 0; i < length; i++) {
labelMap.put(sorted[i], i);
}
}
}
private static String[] sortLabels(Map<String,Integer> labels) {
String[] sorted = new String[labels.keySet().size()];
for(String label: labels.keySet()) {
Integer index = labels.get(label);
sorted[index] = label;
}
return sorted;
}
private static void verifyLabels(int length, String[] sorted) {
Preconditions.checkArgument(sorted.length == length, "One label, one row");
for (int i = 0; i < length; i++) {
if (sorted[i] == null) {
Preconditions.checkArgument(false, "One label, one row");
}
}
}
/**
* This is overloaded. toString() is not a formatted report you print for a manager :)
* Assume that if there are no default assignments, the default feature was not used
*/
@Override
public String toString() {
StringBuilder returnString = new StringBuilder(200);
returnString.append("=======================================================").append('\n');
returnString.append("Confusion Matrix\n");
returnString.append("-------------------------------------------------------").append('\n');
int unclassified = getTotal(defaultLabel);
for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
continue;
}
returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
}
returnString.append("<--Classified as").append('\n');
for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
continue;
}
String correctLabel = entry.getKey();
int labelTotal = 0;
for (String classifiedLabel : this.labelMap.keySet()) {
if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
continue;
}
returnString.append(
StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
labelTotal += getCount(correctLabel, classifiedLabel);
}
returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
.append(" = ").append(correctLabel).append('\n');
}
if (unclassified > 0) {
returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
}
returnString.append('\n');
return returnString.toString();
}
static String getSmallLabel(int i) {
int val = i;
StringBuilder returnString = new StringBuilder();
do {
int n = val % 26;
returnString.insert(0, (char) ('a' + n));
val /= 26;
} while (val > 0);
return returnString.toString();
}
}