package water.api;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.TransfVec;
import water.fvec.Vec;
import water.util.Utils;
import java.util.Arrays;
import static water.util.Utils.printConfusionMatrix;
/**
* Compare two categorical columns, reporting a grid of co-occurrences.
* <br>
* The semantics follow the R approach - see the R code below:
* <pre>
* > l = c("A", "B", "C")
* > a = factor(c("A", "B", "C"), levels=l)
* > b = factor(c("A", "B", "A"), levels=l)
* > confusionMatrix(a,b)
*
* Reference
* Prediction A B C
* A 1 0 0
* B 0 1 0
* C 1 0 0
* </pre>
*
* <p>Note: By default we report zero rows and columns.</p>
*
* @author cliffc
*/
public class ConfusionMatrix extends Func {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

  @API(help = "", required = true, filter = Default.class)
  public Frame actual;

  @API(help="Column of the actual results (will display vertically)", required=true, filter=actualVecSelect.class)
  public Vec vactual;
  class actualVecSelect extends VecClassSelect { actualVecSelect() { super("actual"); } }

  @API(help = "", required = true, filter = Default.class)
  public Frame predict;

  @API(help="Column of the predicted results (will display horizontally)", required=true, filter=predictVecSelect.class)
  public Vec vpredict;
  class predictVecSelect extends VecClassSelect { predictVecSelect() { super("predict"); } }

  @API(help="domain of the actual response")
  String [] actual_domain;

  @API(help="domain of the predicted response")
  String [] predicted_domain;

  @API(help="union of domains")
  public
  String [] domain;

  @API(help="Confusion Matrix (or co-occurrence matrix)")
  public long cm[][];

  @API(help="Mean Squared Error")
  public double mse = Double.NaN;

  // True when both input vectors are integer-valued; selects the confusion-matrix
  // path in execImpl(). Otherwise the mean squared error is computed instead.
  private boolean classification;

  /**
   * Validates the two input vectors and decides between classification and regression scoring.
   *
   * @throws IllegalArgumentException if either vector is missing, if the vectors differ in
   *         length, or if a categorical vector is supplied for regression scoring
   */
  @Override protected void init() throws IllegalArgumentException {
    // Input handling: check for missing vectors BEFORE dereferencing them, so the caller
    // gets the intended IllegalArgumentException rather than a NullPointerException.
    if( vactual==null || vpredict==null )
      throw new IllegalArgumentException("Missing actual or predict!");
    if (vactual.length() != vpredict.length())
      throw new IllegalArgumentException("Both arguments must have the same length!");
    classification = vactual.isInt() && vpredict.isInt();
    // Handle regression kind which is producing CM 1x1 elements
    if (!classification && vactual.isEnum())
      throw new IllegalArgumentException("Actual vector cannot be categorical for regression scoring.");
    if (!classification && vpredict.isEnum())
      throw new IllegalArgumentException("Predicted vector cannot be categorical for regression scoring.");
  }

  /**
   * Computes either the confusion matrix ({@link #cm}, classification) or the mean squared
   * error ({@link #mse}, regression) of {@link #vactual} versus {@link #vpredict}.
   *
   * <p>For classification both vectors are converted to enums; when their domains differ,
   * both are remapped onto the union of the two domains so that row and column indices of
   * the matrix refer to the same levels. Temporary vectors created along the way are removed
   * in the {@code finally} block.</p>
   */
  @Override protected void execImpl() {
    Vec va = null,vp = null, avp = null;
    try {
      if (classification) {
        // Create a new vectors - it is cheap since vector are only adaptation vectors
        va = vactual .toEnum(); // always returns TransfVec
        actual_domain = va._domain;
        vp = vpredict.toEnum(); // always returns TransfVec
        predicted_domain = vp._domain;
        if (!Arrays.equals(actual_domain, predicted_domain)) {
          // Domains differ: remap both vectors onto the union of the two domains
          domain = Utils.domainUnion(actual_domain, predicted_domain);
          int[][] vamap = Model.getDomainMapping(domain, actual_domain, true);
          va = TransfVec.compose( (TransfVec) va, vamap, domain, false ); // delete original va
          int[][] vpmap = Model.getDomainMapping(domain, predicted_domain, true);
          vp = TransfVec.compose( (TransfVec) vp, vpmap, domain, false ); // delete original vp
        } else domain = actual_domain;
        // The vectors are from different groups => align them, but properly delete it after computation
        if (!va.group().equals(vp.group())) {
          avp = vp; // remember the unaligned vector so it is removed in finally as well
          vp = va.align(vp);
        }
        // NOTE(review): CM treats a length of 1 as "regression", so a single-level domain
        // leaves cm == null here - confirm whether that is intended.
        cm = new CM(domain.length).doAll(va,vp)._cm;
      } else {
        mse = new CM(1).doAll(vactual,vpredict).mse();
      }
    } finally { // Delete adaptation vectors (all three stay null on the regression path)
      if (va!=null) UKV.remove(va._key);
      if (vp!=null) UKV.remove(vp._key);
      if (avp!=null) UKV.remove(avp._key);
    }
  }

  // Compute the co-occurrence matrix (classification) or the squared-error sum (regression)
  private static class CM extends MRTask2<CM> {
    /* @IN */ final int _c_len;
    /* @OUT Classification */ long _cm[][];
    /* @OUT Regression */ public double mse() { return _count > 0 ? _mse/_count : Double.POSITIVE_INFINITY; }
    /* @OUT Regression Helper */ private double _mse;
    /* @OUT Regression Helper */ private long _count;

    /** @param c_len number of class levels; a value of 1 (or less) selects the regression path */
    CM(int c_len) { _c_len = c_len; }

    /**
     * Processes one pair of chunks: tallies co-occurrences into {@code _cm} for
     * classification, or accumulates the squared error and row count for regression.
     *
     * @param ca chunk of actual values
     * @param cp chunk of predicted values
     */
    @Override public void map( Chunk ca, Chunk cp ) {
      //classification
      if (_c_len > 1) {
        // One extra row and column (index _c_len) collects missing values (NAs)
        _cm = new long[_c_len+1][_c_len+1];
        int len = Math.min(ca._len,cp._len); // handle different lengths, but the vectors should have been rejected already
        for( int i=0; i < len; i++ ) {
          int a=ca.isNA0(i) ? _c_len : (int)ca.at80(i);
          int p=cp.isNA0(i) ? _c_len : (int)cp.at80(i);
          _cm[a][p]++;
        }
        // Rows present in only one of the chunks are counted against the NA slot of the other
        if( len < ca._len )
          for( int i=len; i < ca._len; i++ )
            _cm[ca.isNA0(i) ? _c_len : (int)ca.at80(i)][_c_len]++;
        if( len < cp._len )
          for( int i=len; i < cp._len; i++ )
            _cm[_c_len][cp.isNA0(i) ? _c_len : (int)cp.at80(i)]++;
      } else {
        _cm = null;
        _mse = 0;
        assert(ca._len == cp._len);
        int len = ca._len;
        for( int i=0; i < len; i++ ) {
          if (ca.isNA0(i) || cp.isNA0(i)) continue; //TODO: Improve
          final double a=ca.at0(i);
          final double p=cp.at0(i);
          _mse += (p-a)*(p-a);
          _count++;
        }
      }
    }

    /**
     * Merges the partial result of another task into this one.
     *
     * @param cm the other task whose matrix, or squared-error sum and count, is added
     */
    @Override public void reduce( CM cm ) {
      if (_cm != null && cm._cm != null) {
        Utils.add(_cm,cm._cm);
      } else {
        // NaN never compares equal to anything (itself included), so it must be
        // detected with Double.isNaN rather than with == or !=.
        assert(!Double.isNaN(_mse) && !Double.isNaN(cm._mse));
        assert(_cm == null && cm._cm == null);
        _mse += cm._mse;
        _count += cm._count;
      }
    }
  }

  /**
   * Renders the result as HTML: the confusion matrix for classification, or a
   * one-cell table holding the mean squared error for regression.
   *
   * @param sb buffer the HTML is appended to
   * @return always {@code true}
   */
  @Override public boolean toHTML( StringBuilder sb ) {
    if (classification) {
      DocGen.HTML.section(sb,"Confusion Matrix");
      if( cm == null ) return true;
      printConfusionMatrix(sb, cm, domain, true);
    } else{
      DocGen.HTML.section(sb,"Mean Squared Error");
      // mse keeps its NaN initial value until execImpl() has computed it
      if( Double.isNaN(mse) ) return true;
      DocGen.HTML.arrayHead(sb);
      sb.append("<tr class='warning'><td>" + mse + "</td></tr>");
      DocGen.HTML.arrayTail(sb);
    }
    return true;
  }

  /**
   * Renders the result as plain text: the confusion matrix for classification
   * (nothing if it has not been computed), or an {@code MSE:} line for regression.
   *
   * @param sb buffer the text is appended to
   */
  public void toASCII( StringBuilder sb ) {
    if (classification) {
      if(cm == null) return;
      printConfusionMatrix(sb, cm, domain, false);
    } else {
      sb.append("MSE: " + mse);
    }
  }
}