package hex;
import water.Iced;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;
import java.util.Arrays;
public class ConfusionMatrix extends Iced {
private TwoDimTable _table;
public final double[][] _cm; // [actual][predicted], typed as double because of observation weights (which can be doubles)
public final String[] _domain;
public static final int MAX_CM_CLASSES = 1000;
/**
* Constructor for Confusion Matrix
* @param value 2D square matrix with co-occurrence counts for actual vs predicted class membership
* @param domain class labels (unified domain between actual and predicted class labels)
*/
public ConfusionMatrix(double[][] value, String[] domain) { _cm = value; _domain = domain; }
/** Build the CM data from the actuals and predictions, using the default
* threshold. Print to Log.info if the number of classes is below the
* print_threshold. Actuals might have extra levels not trained on (hence
* never predicted). Actuals with NAs are not scored, and their predictions
* ignored. */
public static ConfusionMatrix buildCM(Vec actuals, Vec predictions) {
if (!actuals.isCategorical()) throw new IllegalArgumentException("actuals must be categorical.");
if (!predictions.isCategorical()) throw new IllegalArgumentException("predictions must be categorical.");
Scope.enter();
try {
Vec adapted = predictions.adaptTo(actuals.domain());
int len = actuals.domain().length;
CMBuilder cm = new CMBuilder(len).doAll(actuals, adapted);
return new ConfusionMatrix(cm._arr, actuals.domain());
} finally {
Scope.exit();
}
}
private static class CMBuilder extends MRTask<CMBuilder> {
final int _len;
double _arr[/*actuals*/][/*predicted*/];
CMBuilder(int len) { _len = len; }
@Override public void map( Chunk ca, Chunk cp ) {
if (_len > MAX_CM_CLASSES) return;
// After adapting frames, the Actuals have all the levels in the
// prediction results, plus any extras the model was never trained on.
// i.e., Actual levels are at least as big as the predicted levels.
_arr = new double[_len][_len];
for( int i=0; i < ca._len; i++ )
if( !ca.isNA(i) )
_arr[(int)ca.at8(i)][(int)cp.at8(i)]++;
}
@Override public void reduce( CMBuilder cm ) {
ArrayUtils.add(_arr,cm._arr);
}
}
public void add(int i, int j) { _cm[i][j]++; }
public final int size() { return _domain.length; }
boolean tooLarge() { return size() > MAX_CM_CLASSES; }
public final double mean_per_class_error() {
if(tooLarge())throw new UnsupportedOperationException("mean per class error cannot be computed: too many classes");
double err = 0;
for( int d = 0; d < _cm.length; ++d )
err += class_error(d); //can be 0 if no actuals, but we're still dividing by the total count of classes
return err / _cm.length;
}
// mean(accuracy) = mean(1-error) = 1-mean(error)
public final double mean_per_class_accuracy() {
return 1-mean_per_class_error();
}
public final double class_error(int c) {
if(tooLarge())throw new UnsupportedOperationException("class errors cannot be computed: too many classes");
double s = ArrayUtils.sum(_cm[c]);
if( s == 0 ) return 0.0; // Either 0 or NaN, but 0 is nicer
return (s - _cm[c][c]) / s;
}
public double total_rows() {
double n = 0;
for (double[] a_arr : _cm)
n += ArrayUtils.sum(a_arr);
return n;
}
public void add(ConfusionMatrix other) {
if (_cm != null && other._cm != null)
ArrayUtils.add(_cm, other._cm);
}
/**
* @return overall classification error
*/
public double err() {
if(tooLarge())throw new UnsupportedOperationException("error cannot be computed: too many classes");
double n = total_rows();
double err = n;
for( int d = 0; d < _cm.length; ++d )
err -= _cm[d][d];
return err / n;
}
public double err_count() {
if(tooLarge())throw new UnsupportedOperationException("error count cannot be computed: too many classes");
double err = total_rows();
for( int d = 0; d < _cm.length; ++d )
err -= _cm[d][d];
assert(err >= 0);
return err;
}
/**
* The percentage of predictions that are correct.
*/
public double accuracy() { return 1-err(); }
/**
* The percentage of negative labeled instances that were predicted as negative.
* @return TNR / Specificity
*/
public double specificity() {
if(!isBinary())throw new UnsupportedOperationException("specificity is only implemented for 2 class problems.");
if(tooLarge())throw new UnsupportedOperationException("specificity cannot be computed: too many classes");
double tn = _cm[0][0];
double fp = _cm[0][1];
return tn / (tn + fp);
}
/**
* The percentage of positive labeled instances that were predicted as positive.
* @return Recall / TPR / Sensitivity
*/
public double recall() {
if(!isBinary())throw new UnsupportedOperationException("recall is only implemented for 2 class problems.");
if(tooLarge())throw new UnsupportedOperationException("recall cannot be computed: too many classes");
double tp = _cm[1][1];
double fn = _cm[1][0];
return tp / (tp + fn);
}
/**
* The percentage of positive predictions that are correct.
* @return Precision
*/
public double precision() {
if(!isBinary())throw new UnsupportedOperationException("precision is only implemented for 2 class problems.");
if(tooLarge())throw new UnsupportedOperationException("precision cannot be computed: too many classes");
double tp = _cm[1][1];
double fp = _cm[0][1];
return tp / (tp + fp);
}
/**
* The Matthews Correlation Coefficient, takes true negatives into account in contrast to F-Score
* See <a href="http://en.wikipedia.org/wiki/Matthews_correlation_coefficient">MCC</a>
* MCC = Correlation between observed and predicted binary classification
* @return mcc ranges from -1 (total disagreement) ... 0 (no better than random) ... 1 (perfect)
*/
public double mcc() {
if(!isBinary())throw new UnsupportedOperationException("mcc is only implemented for 2 class problems.");
if(tooLarge())throw new UnsupportedOperationException("mcc cannot be computed: too many classes");
double tn = _cm[0][0];
double fp = _cm[0][1];
double tp = _cm[1][1];
double fn = _cm[1][0];
return (tp*tn - fp*fn)/Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn));
}
/**
* The maximum per-class error
* @return max[classErr(i)]
*/
public double max_per_class_error() {
int n = nclasses();
if(n == 0)throw new UnsupportedOperationException("max per class error is only defined for classification problems");
if(tooLarge())throw new UnsupportedOperationException("max per class error cannot be computed: too many classes");
double res = class_error(0);
for(int i = 1; i < n; ++i)
res = Math.max(res, class_error(i));
return res;
}
public final int nclasses(){return _domain == null ? 0: _domain.length;}
public final boolean isBinary(){return nclasses() == 2;}
/**
* Returns the F-measure which combines precision and recall. <br>
* C.f. end of http://en.wikipedia.org/wiki/Precision_and_recall.
*/
public double f1() {
final double precision = precision();
final double recall = recall();
return 2. * (precision * recall) / (precision + recall);
}
/**
* Returns the F-measure which combines precision and recall and weights recall higher than precision. <br>
* See <a href="http://en.wikipedia.org/wiki/F1_score.">F1_score</a>
*/
public double f2() {
final double precision = precision();
final double recall = recall();
return 5. * (precision * recall) / (4. * precision + recall);
}
/**
* Returns the F-measure which combines precision and recall and weights precision higher than recall. <br>
* See <a href="http://en.wikipedia.org/wiki/F1_score.">F1_score</a>
*/
public double f0point5() {
final double precision = precision();
final double recall = recall();
return 1.25 * (precision * recall) / (.25 * precision + recall);
}
@Override public String toString() {
StringBuilder sb = new StringBuilder();
for( double[] r : _cm)
sb.append(Arrays.toString(r)).append('\n');
return sb.toString();
}
private static String[] createConfusionMatrixHeader( double xs[], String ds[] ) {
String ss[] = new String[xs.length]; // the same length
for( int i=0; i<xs.length; i++ )
if( xs[i] >= 0 || (ds[i] != null && ds[i].length() > 0) && !Double.toString(i).equals(ds[i]) )
ss[i] = ds[i];
if( ds.length == xs.length-1 && xs[xs.length-1] > 0 )
ss[xs.length-1] = "NA";
return ss;
}
public String toASCII() { return table() == null ? "" : _table.toString(); }
/** Convert this ConfusionMatrix into a fully annotated TwoDimTable
* @return TwoDimTable */
public TwoDimTable table() { return _table == null ? (_table=toTable()) : _table; }
// Do the work making a TwoDimTable
private TwoDimTable toTable() {
if (tooLarge()) return null;
if (_cm == null || _domain == null) return null;
for( double cm[] : _cm ) assert(_cm.length == cm.length);
// Sum up predicted & actuals
double acts [] = new double[_cm.length];
double preds[] = new double[_cm[0].length];
boolean isInt = true;
for( int a=0; a< _cm.length; a++ ) {
double sum=0;
for( int p=0; p< _cm[a].length; p++ ) {
sum += _cm[a][p];
preds[p] += _cm[a][p];
isInt &= (_cm[a][p] == (long)_cm[a][p]);
}
acts[a] = sum;
}
String adomain[] = createConfusionMatrixHeader(acts , _domain);
String pdomain[] = createConfusionMatrixHeader(preds, _domain);
assert adomain.length == pdomain.length : "The confusion matrix should have the same length for both directions.";
String[] rowHeader = Arrays.copyOf(adomain,adomain.length+1);
rowHeader[adomain.length] = "Totals";
String[] colHeader = Arrays.copyOf(pdomain,pdomain.length+2);
colHeader[colHeader.length-2] = "Error";
colHeader[colHeader.length-1] = "Rate";
String[] colType = new String[colHeader.length];
String[] colFormat = new String[colHeader.length];
for (int i=0; i<colFormat.length-1; ++i) {
colType[i] = isInt ? "long":"double";
colFormat[i] = isInt ? "%d":"%.2f";
}
colType[colFormat.length-2] = "double";
colFormat[colFormat.length-2] = "%.4f";
colType[colFormat.length-1] = "string";
// pass 1: compute width of last column
double terr = 0;
int width = 0;
for (int a = 0; a < _cm.length; a++) {
if (adomain[a] == null) continue;
double correct = 0;
for (int p = 0; p < pdomain.length; p++) {
if (pdomain[p] == null) continue;
boolean onDiag = adomain[a].equals(pdomain[p]);
if (onDiag) correct = _cm[a][p];
}
double err = acts[a] - correct;
terr += err;
width = isInt ?
Math.max(width, String.format("%,d / %,d", (long)err, (long)acts[a]).length()):
Math.max(width, String.format("%.4f / %.4f", err, acts[a]).length());
}
double nrows = 0;
for (double n : acts) nrows += n;
width = isInt?
Math.max(width, String.format("%,d / %,d", (long)terr, (long)nrows).length()):
Math.max(width, String.format("%.4f / %.4f", terr, nrows).length());
// set format width
colFormat[colFormat.length-1] = "= %" + width + "s";
TwoDimTable table = new TwoDimTable("Confusion Matrix", "vertical: actual; across: predicted", rowHeader, colHeader, colType, colFormat, null);
// Main CM Body
for (int a = 0; a < _cm.length; a++) {
if (adomain[a] == null) continue;
double correct = 0;
for (int p = 0; p < pdomain.length; p++) {
if (pdomain[p] == null) continue;
boolean onDiag = adomain[a].equals(pdomain[p]);
if (onDiag) correct = _cm[a][p];
if (isInt)
table.set(a, p, (long)_cm[a][p]);
else
table.set(a, p, _cm[a][p]);
}
double err = acts[a] - correct;
table.set(a, pdomain.length, err / acts[a]);
table.set(a, pdomain.length + 1,
isInt ? String.format("%,d / %,d", (long)err, (long)acts[a]):
String.format("%.4f / %.4f", err, acts[a])
);
}
// Last row of CM
for (int p = 0; p < pdomain.length; p++) {
if (pdomain[p] == null) continue;
if (isInt)
table.set(adomain.length, p, (long)preds[p]);
else
table.set(adomain.length, p, preds[p]);
}
table.set(adomain.length, pdomain.length, (float) terr / nrows);
table.set(adomain.length, pdomain.length + 1,
isInt ? String.format("%,d / %,d", (long)terr, (long)nrows):
String.format("%.2f / %.2f", terr, nrows));
return table;
}
}