package hex;

import static water.api.DocGen.FieldDoc;
import static water.util.Utils.printConfusionMatrix;

import dontweave.gson.JsonArray;
import dontweave.gson.JsonPrimitive;
import water.Iced;
import water.api.Request.API;

import java.util.Arrays;

/**
 * Confusion matrix of classification counts, indexed as {@code [actual][predicted]}.
 *
 * <p>Besides the raw counts, the object caches the per-class error ({@link #_classErr}) and the
 * overall error ({@link #_predErr}). The cache is filled by the constructors and refreshed only by
 * {@link #reComputeErrors()}; the {@code add} methods change the counts without touching it.
 *
 * <p>The binary-only metrics (specificity, recall, precision, MCC, F-measures) treat class 0 as
 * the negative class and class 1 as the positive class.
 */
public class ConfusionMatrix extends Iced {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  static public FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

  @API(help="Confusion matrix (Actual/Predicted)")
  public long[][] _arr; // [actual][predicted]
  @API(help = "Prediction error by class")
  public final double[] _classErr;
  @API(help = "Prediction error")
  public double _predErr;

  /**
   * Creates a deep copy of this matrix.
   *
   * <p>The copy owns its own count rows, and its cached error fields are computed from the copied
   * counts, so they always have one entry per class.
   *
   * @return an independent copy of this confusion matrix
   */
  @Override public ConfusionMatrix clone() {
    long[][] copy = new long[_arr.length][];
    for( int i = 0; i < _arr.length; ++i )
      copy[i] = _arr[i].clone();
    // Go through the array constructor so _classErr is sized for the copied matrix
    // (a matrix built with dimension 0 would be stuck with an empty, final _classErr).
    return new ConfusionMatrix(copy);
  }

  /** Ways of reducing a confusion matrix to a single error number. */
  public enum ErrMetric {
    /** Largest per-class error. */
    MAXC,
    /** Sum of the per-class errors. */
    SUMC,
    /** Overall classification error. */
    TOTAL;

    /**
     * Computes this metric for the given matrix.
     *
     * @param cm matrix to evaluate
     * @return the error value according to this metric
     */
    public double computeErr(ConfusionMatrix cm) {
      double[] cerr = cm.classErr();
      double res = 0;
      switch( this ) {
        case MAXC:
          res = cerr[0];
          for( double d : cerr )
            if( d > res ) res = d;
          break;
        case SUMC:
          for( double d : cerr )
            res += d;
          break;
        case TOTAL:
          res = cm.err();
          break;
        default:
          throw new RuntimeException("unexpected err metric " + this);
      }
      return res;
    }
  }

  /**
   * Creates an all-zero {@code n x n} matrix.
   *
   * @param n number of classes
   */
  public ConfusionMatrix(int n) {
    _arr = new long[n][n];
    _classErr = classErr();
    _predErr = err();
  }

  /**
   * Wraps the given counts. The array is stored as-is, not copied.
   *
   * @param value counts indexed as {@code [actual][predicted]}
   */
  public ConfusionMatrix(long[][] value) {
    _arr = value;
    _classErr = classErr();
    _predErr = err();
  }

  /**
   * Copies the leading {@code dim x dim} corner of the given counts into a new matrix.
   *
   * @param value source counts indexed as {@code [actual][predicted]}
   * @param dim   number of rows and columns to copy
   */
  public ConfusionMatrix(long[][] value, int dim) {
    _arr = new long[dim][dim];
    for( int i = 0; i < dim; ++i )
      System.arraycopy(value[i], 0, _arr[i], 0, dim);
    _classErr = classErr();
    _predErr = err();
  }

  /**
   * Counts one more observation of actual class {@code i} predicted as class {@code j}.
   * The cached error fields are not updated.
   */
  public void add(int i, int j) {
    _arr[i][j]++;
  }

  /**
   * Computes the error rate of every class from the current counts.
   *
   * @return a fresh array holding {@link #classErr(int)} for each class
   */
  public double[] classErr() {
    double[] res = new double[_arr.length];
    for( int i = 0; i < res.length; ++i )
      res[i] = classErr(i);
    return res;
  }

  /** @return the number of rows of the matrix */
  public final int size() {
    return _arr.length;
  }

  /** Refreshes the cached {@link #_classErr} and {@link #_predErr} from the current counts. */
  public void reComputeErrors() {
    for( int i = 0; i < _arr.length; ++i )
      _classErr[i] = classErr(i);
    _predErr = err();
  }

  /**
   * @param c actual class index
   * @return number of observations of class {@code c} that were predicted as another class
   */
  public final long classErrCount(int c) {
    long s = 0;
    for( long x : _arr[c] )
      s += x;
    return s - _arr[c][c];
  }

  /**
   * @param c actual class index
   * @return fraction of observations of class {@code c} that were predicted as another class,
   *         or 0 when the class has no observations
   */
  public final double classErr(int c) {
    long s = 0;
    for( long x : _arr[c] )
      s += x;
    if( s == 0 )
      return 0.0; // Either 0 or NaN, but 0 is nicer
    return (double) (s - _arr[c][c]) / s;
  }

  /** @return the sum of all counts in the matrix */
  public long totalRows() {
    long n = 0;
    for( int a = 0; a < _arr.length; ++a )
      for( int p = 0; p < _arr[a].length; ++p )
        n += _arr[a][p];
    return n;
  }

  /**
   * Adds the counts of another matrix to this one via {@code water.util.Utils.add}.
   * The cached error fields are not updated.
   *
   * @param other matrix whose counts are added
   */
  public void add(ConfusionMatrix other) {
    water.util.Utils.add(_arr, other._arr);
  }

  /**
   * @return overall classification error; NaN when the matrix holds no observations
   */
  public double err() {
    long n = totalRows();
    return (double) (n - correctCount()) / n;
  }

  /** @return the number of observations that lie off the diagonal */
  public long errCount() {
    return totalRows() - correctCount();
  }

  /** Sums the diagonal, i.e. the observations whose predicted class equals the actual class. */
  private long correctCount() {
    long correct = 0;
    for( int d = 0; d < _arr.length; ++d )
      correct += _arr[d][d];
    return correct;
  }

  /**
   * The percentage of predictions that are correct.
   */
  public double accuracy() {
    return 1 - err();
  }

  /**
   * The percentage of negative labeled instances that were predicted as negative.
   * @return TNR / Specificity
   * @throws UnsupportedOperationException if the matrix does not have exactly 2 classes
   */
  public double specificity() {
    if( !isBinary() )
      throw new UnsupportedOperationException("specificity is only implemented for 2 class problems.");
    double tn = _arr[0][0];
    double fp = _arr[0][1];
    return tn / (tn + fp);
  }

  /**
   * The percentage of positive labeled instances that were predicted as positive.
   * @return Recall / TPR / Sensitivity
   * @throws UnsupportedOperationException if the matrix does not have exactly 2 classes
   */
  public double recall() {
    if( !isBinary() )
      throw new UnsupportedOperationException("recall is only implemented for 2 class problems.");
    double tp = _arr[1][1];
    double fn = _arr[1][0];
    return tp / (tp + fn);
  }

  /**
   * The percentage of positive predictions that are correct.
   * @return Precision
   * @throws UnsupportedOperationException if the matrix does not have exactly 2 classes
   */
  public double precision() {
    if( !isBinary() )
      throw new UnsupportedOperationException("precision is only implemented for 2 class problems.");
    double tp = _arr[1][1];
    double fp = _arr[0][1];
    return tp / (tp + fp);
  }

  /**
   * The Matthews Correlation Coefficient, takes true negatives into account in contrast to F-Score.
   * MCC = Correlation between observed and predicted binary classification
   * @return mcc ranges from -1 (total disagreement) ... 0 (no better than random) ... 1 (perfect)
   * @throws UnsupportedOperationException if the matrix does not have exactly 2 classes
   */
  public double mcc() {
    if( !isBinary() )
      throw new UnsupportedOperationException("mcc is only implemented for 2 class problems.");
    double tn = _arr[0][0];
    double fp = _arr[0][1];
    double tp = _arr[1][1];
    double fn = _arr[1][0];
    return (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
  }

  /**
   * The maximum per-class error
   * @return max(classErr(i))
   * @throws UnsupportedOperationException if the matrix has no classes
   */
  public double max_per_class_error() {
    int n = nclasses();
    if( n == 0 )
      throw new UnsupportedOperationException("max per class error is only defined for classification problems");
    double res = classErr(0);
    for( int i = 1; i < n; ++i )
      res = Math.max(res, classErr(i));
    return res;
  }

  /** @return the number of classes, or 0 when no matrix is set */
  public final int nclasses() {
    return _arr == null ? 0 : _arr.length;
  }

  /** @return true when the matrix has exactly two classes */
  public final boolean isBinary() {
    return nclasses() == 2;
  }

  /**
   * Returns the F-measure which combines precision and recall in a balanced way.
   */
  public double F1() {
    final double precision = precision();
    final double recall = recall();
    return 2. * (precision * recall) / (precision + recall);
  }

  /**
   * Returns the F-measure which combines precision and recall and weights recall higher than precision.
   */
  public double F2() {
    final double precision = precision();
    final double recall = recall();
    return 5. * (precision * recall) / (4. * precision + recall);
  }

  /**
   * Returns the F-measure which combines precision and recall and weights precision higher than recall.
   */
  public double F0point5() {
    final double precision = precision();
    final double recall = recall();
    return 1.25 * (precision * recall) / (.25 * precision + recall);
  }

  /** Renders one line per actual class, each in {@link Arrays#toString(long[])} form. */
  @Override public String toString() {
    StringBuilder sb = new StringBuilder();
    for( long[] r : _arr )
      sb.append(Arrays.toString(r)).append("\n");
    return sb.toString();
  }

  /**
   * Renders the matrix as a JSON table: a header row, one row per actual class (label, counts,
   * row error) and a final "Totals" row with the per-predicted-class column sums and the overall
   * error. Rows or totals without observations yield a NaN error (0/0 in double arithmetic).
   *
   * @return the table as a JSON array of JSON arrays
   */
  public JsonArray toJson() {
    JsonArray res = new JsonArray();

    JsonArray header = new JsonArray();
    header.add(new JsonPrimitive("Actual / Predicted"));
    for( int i = 0; i < _arr.length; ++i )
      header.add(new JsonPrimitive("class " + i));
    header.add(new JsonPrimitive("Error"));
    res.add(header);

    for( int i = 0; i < _arr.length; ++i ) {
      JsonArray row = new JsonArray();
      row.add(new JsonPrimitive("class " + i));
      long s = 0;
      for( int j = 0; j < _arr.length; ++j ) {
        s += _arr[i][j];
        row.add(new JsonPrimitive(_arr[i][j]));
      }
      double err = s - _arr[i][i];
      err /= s;
      row.add(new JsonPrimitive(err));
      res.add(row);
    }

    JsonArray totals = new JsonArray();
    totals.add(new JsonPrimitive("Totals"));
    long S = 0;  // grand total of all counts
    long DS = 0; // sum of the diagonal (correct predictions)
    for( int i = 0; i < _arr.length; ++i ) {
      long s = 0; // column sum: everything predicted as class i
      for( int j = 0; j < _arr.length; ++j )
        s += _arr[j][i];
      totals.add(new JsonPrimitive(s));
      S += s;
      DS += _arr[i][i];
    }
    double err = (S - DS) / (double) S;
    totals.add(new JsonPrimitive(err));
    res.add(totals);
    return res;
  }

  /**
   * Appends an HTML rendering of the matrix to {@code sb} by delegating to
   * {@code Utils.printConfusionMatrix}.
   *
   * @param sb     buffer to append to
   * @param domain class labels passed through to the printer
   */
  public void toHTML(StringBuilder sb, String[] domain) {
    long[][] cm = _arr;
    printConfusionMatrix(sb, cm, domain, true);
  }
}