/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.classifier;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import com.google.common.collect.Lists;
/**
* Export a ConfusionMatrix in various text formats:
* ToString version
* Grayscale HTML table
* Summary HTML table
* Table of counts
* all with optional HTML wrappers
*
* Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair
*
* Intended to consume ConfusionMatrix SequenceFile output by Bayes
* TestClassifier class
*/
public final class ConfusionMatrixDumper extends AbstractJob {
// HTML wrapper - default CSS
private static final String HEADER = "<html>" +
"<head>\n" +
"<title>TITLE</title>\n" +
"</head>" +
"<body>\n" +
"<style type='text/css'> \n" +
"table\n" +
"{\n" +
"border:3px solid black; text-align:left;\n" +
"}\n" +
"th.normalHeader\n" +
"{\n" +
"border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" +
"}\n" +
"th.tallHeader\n" +
"{\n" +
"border:1px solid black;border-collapse:collapse;text-align:center;background-color:white; height:6em\n" +
"}\n" +
"tr.label\n" +
"{\n" +
"border:1px solid black;border-collapse:collapse;text-align:center;background-color:white\n" +
"}\n" +
"tr.row\n" +
"{\n" +
"border:1px solid gray;text-align:center;background-color:snow\n" +
"}\n" +
"td\n" +
"{\n" +
"min-width:2em\n" +
"}\n" +
"td.cell\n" +
"{\n" +
"border:1px solid black;text-align:right;background-color:snow\n" +
"}\n" +
"td.empty\n" +
"{\n" +
"border:0px;text-align:right;background-color:snow\n" +
"}\n" +
"td.white\n" +
"{\n" +
"border:0px solid black;text-align:right;background-color:white\n" +
"}\n" +
"td.black\n" +
"{\n" +
"border:0px solid red;text-align:right;background-color:black\n" +
"}\n" +
"td.gray1\n" +
"{\n" +
"border:0px solid green;text-align:right; background-color:LightGray\n" +
"}\n" +
"td.gray2\n" +
"{\n" +
"border:0px solid blue;text-align:right;background-color:gray\n" +
"}\n" +
"td.gray3\n" +
"{\n" +
"border:0px solid red;text-align:right;background-color:DarkGray\n" +
"}\n" +
"th" +
"{\n" +
" text-align: center;\n" +
" vertical-align: bottom;\n" +
" padding-bottom: 3px;\n" +
" padding-left: 5px;\n" +
" padding-right: 5px;\n" +
"}\n" +
" .verticalText\n" +
" {\n" +
" text-align: center;\n" +
" vertical-align: middle;\n" +
" width: 20px;\n" +
" margin: 0px;\n" +
" padding: 0px;\n" +
" padding-left: 3px;\n" +
" padding-right: 3px;\n" +
" padding-top: 10px;\n" +
" white-space: nowrap;\n" +
" -webkit-transform: rotate(-90deg); \n" +
" -moz-transform: rotate(-90deg); \n" +
" };\n" +
"</style>\n";
private static final String FOOTER = "</html></body>";
// CSS style names.
private static final String CSS_TABLE = "table";
private static final String CSS_LABEL = "label";
private static final String CSS_TALL_HEADER = "tall";
private static final String CSS_VERTICAL = "verticalText";
private static final String CSS_CELL = "cell";
private static final String CSS_EMPTY = "empty";
private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"};
private ConfusionMatrixDumper() {}
public static void main(String[] args) throws Exception {
ToolRunner.run(new ConfusionMatrixDumper(), args);
}
@Override
public int run(String[] args) throws IOException {
addInputOption();
addOption("output", "o", "Output path", null); // AbstractJob output feature requires param
addOption(DefaultOptionCreator.overwriteOption().create());
addFlag("html", null, "Create complete HTML page");
addFlag("text", null, "Dump simple text");
Map<String, String> parsedArgs = parseArguments(args);
if (parsedArgs == null) {
return -1;
}
Path inputPath = getInputPath();
String outputFile = parsedArgs.containsKey("--output") ? parsedArgs.get("--output") : null;
boolean text = parsedArgs.containsKey("--text");
boolean wrapHtml = parsedArgs.containsKey("--html");
PrintStream out = getPrintStream(outputFile);
if (text) {
exportText(inputPath, out);
} else {
exportTable(inputPath, out, wrapHtml);
}
out.flush();
if (out != System.out) {
out.close();
}
return 0;
}
private static void exportText(Path inputPath, PrintStream out) throws IOException {
MatrixWritable mw = new MatrixWritable();
Text key = new Text();
readSeqFile(inputPath, key, mw);
Matrix m = mw.get();
ConfusionMatrix cm = new ConfusionMatrix(m);
out.println(cm.toString());
}
private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException {
MatrixWritable mw = new MatrixWritable();
Text key = new Text();
readSeqFile(inputPath, key, mw);
String fileName = inputPath.getName();
fileName = fileName.substring(fileName.lastIndexOf('/') + 1, fileName.length());
Matrix m = mw.get();
ConfusionMatrix cm = new ConfusionMatrix(m);
if (wrapHtml) {
printHeader(out, fileName);
}
out.println("<p/>");
printSummaryTable(cm, out);
out.println("<p/>");
printGrayTable(cm, out);
out.println("<p/>");
printCountsTable(cm, out);
out.println("<p/>");
printTextInBox(cm, out);
out.println("<p/>");
if (wrapHtml) {
printFooter(out);
}
}
private static List<String> stripDefault(ConfusionMatrix cm) {
List<String> stripped = Lists.newArrayList(cm.getLabels().iterator());
String defaultLabel = cm.getDefaultLabel();
int unclassified = cm.getTotal(defaultLabel);
if (unclassified > 0) {
return stripped;
}
stripped.remove(defaultLabel);
return stripped;
}
// TODO: test - this should work with HDFS files
private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException {
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(conf);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
reader.next(key, m);
}
// TODO: test - this might not work with HDFS files?
// after all, it does no seeks
private static PrintStream getPrintStream(String outputFilename) throws IOException {
if (outputFilename != null) {
File outputFile = new File(outputFilename);
if (outputFile.exists()) {
outputFile.delete();
}
outputFile.createNewFile();
OutputStream os = new FileOutputStream(outputFile);
return new PrintStream(os);
} else {
return System.out;
}
}
private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) {
Iterator<String> iter = cm.getLabels().iterator();
int count = 0;
while(iter.hasNext()) {
count += cm.getCount(rowLabel, iter.next());
}
return count;
}
// HTML generator code
private static void printTextInBox(ConfusionMatrix cm, PrintStream out) {
out.println("<div style='width:90%;overflow:scroll;'>");
out.println("<pre>");
out.println(cm.toString());
out.println("</pre>");
out.println("</div>");
}
public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) {
format("<table class='%s'>\n", out, CSS_TABLE);
format("<tr class='%s'>", out, CSS_LABEL);
out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>");
out.println("</tr>");
List<String> labels = stripDefault(cm);
for(String label: labels) {
printSummaryRow(cm, out, label);
}
out.println("</table>");
}
private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) {
format("<tr class='%s'>", out, CSS_CELL);
int correct = cm.getCorrect(label);
double accuracy = cm.getAccuracy(label);
int count = getCount(cm, label);
format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>",
out, CSS_CELL, label, count, correct, (int) Math.round(accuracy));
out.println("</tr>");
}
private static int getCount(ConfusionMatrix cm, String label) {
int count = 0;
for (String s : cm.getLabels()) {
count += cm.getCount(label, s);
}
return count;
}
public static void printGrayTable(ConfusionMatrix cm, PrintStream out) {
format("<table class='%s'>\n", out, CSS_TABLE);
printCountsHeader(cm, out, true);
printGrayRows(cm, out);
out.println("</table>");
}
/**
* Print each value in a four-value grayscale based on count/max.
* Gives a mostly white matrix with grays in misclassified, and black in diagonal.
* TODO: Using the sqrt(count/max) as the rating is more stringent
*/
private static void printGrayRows(ConfusionMatrix cm, PrintStream out) {
List<String> labels = stripDefault(cm);
for (String label: labels) {
printGrayRow(cm, out, labels, label);
}
}
private static void printGrayRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) {
format("<tr class='%s'>", out, CSS_LABEL);
format("<td>%s</td>", out, rowLabel);
int total = getLabelTotal(cm, rowLabel);
for (String columnLabel: labels) {
printGrayCell(cm, out, total, rowLabel, columnLabel);
}
out.println("</tr>");
}
// assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of inputs
// assign black to count = total, meaning complete success
// alternative rating is to use sqrt(total) instead of total - this is more drastic
private static void printGrayCell(ConfusionMatrix cm,
PrintStream out,
int total,
String rowLabel,
String columnLabel) {
int count = cm.getCount(rowLabel, columnLabel);
if (count == 0) {
out.format("<td class='%s'/>", CSS_EMPTY);
} else {
// 0 is white, full is black, everything else gray
int rating = (int) ((count/ (double) total) * 4);
String css = CSS_GRAY_CELLS[rating];
format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, count);
}
}
public static void printCountsTable(ConfusionMatrix cm, PrintStream out) {
format("<table class='%s'>\n", out, CSS_TABLE);
printCountsHeader(cm, out, false);
printCountsRows(cm, out);
out.println("</table>");
}
private static void printCountsRows(ConfusionMatrix cm, PrintStream out) {
List<String> labels = stripDefault(cm);
for(String label: labels) {
printCountsRow(cm, out, labels, label);
}
}
private static void printCountsRow(ConfusionMatrix cm, PrintStream out, Iterable<String> labels, String rowLabel) {
out.println("<tr>");
format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel);
for(String columnLabel: labels) {
printCountsCell(cm, out, rowLabel, columnLabel);
}
out.println("</tr>");
}
private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) {
int count = cm.getCount(rowLabel, columnLabel);
String s = count == 0 ? "" : Integer.toString(count);
format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, s);
}
private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) {
List<String> labels = stripDefault(cm);
int longest = getLongestHeader(labels);
if (vertical) {
// do vertical - rotation is a bitch
out.format("<tr class='%s' style='height:%dem'><th> </th>\n", CSS_TALL_HEADER, longest/2);
for(String label: labels) {
out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, label);
}
out.println("</tr>");
} else {
// header - empty cell in upper left
out.format("<tr class='%s'><td class='%s'></td>\n", CSS_TABLE, CSS_LABEL);
for(String label: labels) {
out.format("<td>%s</td>", label);
}
out.format("</tr>");
}
}
private static int getLongestHeader(Iterable<String> labels) {
int max = 0;
for (String label: labels) {
max = Math.max(label.length(), max);
}
return max;
}
private static void format(String format, PrintStream out, Object ... args) {
String format2 = String.format(format, args);
out.println(format2);
}
public static void printHeader(PrintStream out, CharSequence title) {
out.println(HEADER.replace("TITLE", title));
}
public static void printFooter(PrintStream out) {
out.println(FOOTER);
}
}