package water.api; import static java.util.Arrays.sort; import hex.ConfusionMatrix; import water.*; import water.fvec.Chunk; import water.fvec.Frame; import water.fvec.Vec; import water.util.Utils; import java.util.HashSet; public class AUC extends Func { static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields static public DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code. public static final String DOC_GET = "AUC"; @API(help = "", required = true, filter = Default.class, json=true) public Frame actual; @API(help="Column of the actual results (will display vertically)", required=true, filter=actualVecSelect.class, json=true) public Vec vactual; class actualVecSelect extends VecClassSelect { actualVecSelect() { super("actual"); } } @API(help = "", required = true, filter = Default.class, json=true) public Frame predict; @API(help="Column of the predicted results (will display horizontally)", required=true, filter=predictVecSelect.class, json=true) public Vec vpredict; class predictVecSelect extends VecClassSelect { predictVecSelect() { super("predict"); } } @API(help = "Thresholds (optional, e.g. 0:1:0.01 or 0.0,0.2,0.4,0.6,0.8,1.0).", required = false, filter = Default.class, json = true) public float[] thresholds; @API(help = "Threshold criterion", filter = Default.class, json = true) public ThresholdCriterion threshold_criterion = ThresholdCriterion.maximum_F1; public enum ThresholdCriterion { maximum_F1, maximum_F2, maximum_F0point5, maximum_Accuracy, maximum_Precision, maximum_Recall, maximum_Specificity, maximum_absolute_MCC, minimizing_max_per_class_Error } @API(help = "AUC Data", json = true) AUCData aucdata; public AUCData data() { return aucdata; } public AUC() {} /** * Constructor for algos that make their own CMs * @param cms ConfusionMatrices * @param thresh Thresholds */ public AUC(hex.ConfusionMatrix[] cms, float[] thresh) { this(cms, thresh, null); } /** * Constructor for algos that make their own CMs * @param cms ConfusionMatrices * @param thresh Thresholds * @param domain Domain */ public AUC(hex.ConfusionMatrix[] cms, float[] thresh, String[] domain) { aucdata = new AUCData().compute(cms, thresh, domain, threshold_criterion); } @Override protected void init() throws IllegalArgumentException { // Input handling if( vactual==null || vpredict==null ) throw new IllegalArgumentException("Missing vactual or vpredict!"); if (vactual.length() != vpredict.length()) throw new IllegalArgumentException("Both arguments must have the same length ("+vactual.length()+"!="+vpredict.length()+")!"); if (!vactual.isInt()) throw new IllegalArgumentException("Actual column must be integer class labels!"); if (vactual.cardinality() != -1 && vactual.cardinality() != 2) throw new IllegalArgumentException("Actual column must contain binary class labels, but found cardinality " + vactual.cardinality() + "!"); if (vpredict.isEnum()) throw new IllegalArgumentException("vpredict cannot be class labels, expect probabilities."); } @Override protected void execImpl() { Vec va = null, vp; try { va = vactual.toEnum(); // always returns TransfVec vp = vpredict; // The vectors are from different groups => align them, but properly delete it after computation if (!va.group().equals(vp.group())) { vp = va.align(vp); } // compute thresholds, if not user-given if (thresholds != null) { sort(thresholds); if (Utils.minValue(thresholds) < 0) throw new IllegalArgumentException("Minimum threshold cannot be negative."); if (Utils.maxValue(thresholds) > 1) throw new IllegalArgumentException("Maximum threshold cannot be greater than 1."); } else { HashSet hs = new HashSet(); final int bins = (int)Math.min(vpredict.length(), 200l); final long stride = Math.max(vpredict.length() / bins, 1); for( int i=0; i<bins; ++i) hs.add(new Float(vpredict.at(i*stride))); //data-driven thresholds TODO: use percentiles (from Summary2?) for (int i=0;i<51;++i) hs.add(new Float(i/50.)); //always add 0.02-spaced thresholds from 0 to 1 // created sorted vector of unique thresholds thresholds = new float[hs.size()]; int i=0; for (Object h : hs) {thresholds[i++] = (Float)h; } sort(thresholds); } // compute CMs aucdata = new AUCData().compute(new AUCTask(thresholds,va.mean()).doAll(va,vp).getCMs(), thresholds, va._domain, threshold_criterion); } catch(Throwable t) { t.printStackTrace(); throw new RuntimeException(t); } finally { // Delete adaptation vectors if (va!=null) UKV.remove(va._key); } } /* return true if a is better than b with respect to criterion criter */ static boolean isBetter(ConfusionMatrix a, ConfusionMatrix b, ThresholdCriterion criter) { if (criter == ThresholdCriterion.maximum_F1) { return (!Double.isNaN(a.F1()) && (Double.isNaN(b.F1()) || a.F1() > b.F1())); } if (criter == ThresholdCriterion.maximum_F2) { return (!Double.isNaN(a.F2()) && (Double.isNaN(b.F2()) || a.F2() > b.F2())); } if (criter == ThresholdCriterion.maximum_F0point5) { return (!Double.isNaN(a.F0point5()) && (Double.isNaN(b.F0point5()) || a.F0point5() > b.F0point5())); } else if (criter == ThresholdCriterion.maximum_Recall) { return (!Double.isNaN(a.recall()) && (Double.isNaN(b.recall()) || a.recall() > b.recall())); } else if (criter == ThresholdCriterion.maximum_Precision) { return (!Double.isNaN(a.precision()) && (Double.isNaN(b.precision()) || a.precision() > b.precision())); } else if (criter == ThresholdCriterion.maximum_Accuracy) { return a.accuracy() > b.accuracy(); } else if (criter == ThresholdCriterion.minimizing_max_per_class_Error) { return a.max_per_class_error() < b.max_per_class_error(); } else if (criter == ThresholdCriterion.maximum_Specificity) { return (!Double.isNaN(a.specificity()) && (Double.isNaN(b.specificity()) || a.specificity() > b.specificity())); } else if (criter == ThresholdCriterion.maximum_absolute_MCC) { return (!Double.isNaN(a.mcc()) && (Double.isNaN(b.mcc()) || Math.abs(a.mcc()) > Math.abs(b.mcc()))); } else { throw new IllegalArgumentException("Unknown threshold criterion."); } } @Override public boolean toHTML( StringBuilder sb ) { return aucdata.toHTML(sb); } public void toASCII( StringBuilder sb ) { aucdata.toASCII(sb); } // Compute CMs for different thresholds via MRTask2 private static class AUCTask extends MRTask2<AUCTask> { /* @OUT CMs */ public final hex.ConfusionMatrix[] getCMs() { return _cms; } private hex.ConfusionMatrix[] _cms; double nullDev; double resDev; final double ymu; /* IN thresholds */ final private float[] _thresh; AUCTask(float[] thresh, double mu) { _thresh = thresh.clone(); ymu = mu; } static final double y_log_y(double y, double mu) { if(y == 0)return 0; if(mu < Double.MIN_NORMAL) mu = Double.MIN_NORMAL; return y * Math.log(y / mu); } public static double binomial_deviance(double yreal, double ymodel){ return 2 * ((y_log_y(yreal, ymodel)) + y_log_y(1 - yreal, 1 - ymodel)); } @Override public void map( Chunk ca, Chunk cp ) { _cms = new hex.ConfusionMatrix[_thresh.length]; for (int i=0;i<_cms.length;++i) _cms[i] = new hex.ConfusionMatrix(2); final int len = Math.min(ca._len, cp._len); for( int i=0; i < len; i++ ) { if (ca.isNA0(i)) continue; // throw new UnsupportedOperationException("Actual class label cannot be a missing value!"); final int a = (int)ca.at80(i); //would be a 0 if double was NaN assert (a == 0 || a == 1) : "Invalid values in vactual: must be binary (0 or 1)."; if (cp.isNA0(i)) { // Log.warn("Skipping predicted NaN."); //some models predict NaN! continue; } final double pr = cp.at0(i); for( int t=0; t < _cms.length; t++ ) { final int p = pr >= _thresh[t]?1:0; _cms[t].add(a, p); } } } @Override public void reduce( AUCTask other ) { for( int i=0; i<_cms.length; ++i) { _cms[i].add(other._cms[i]); } nullDev += other.nullDev; resDev += other.resDev; } @Override public void postGlobal(){ } } }