package hex; import java.util.Arrays; import water.Iced; import water.MRTask; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Chunk; import water.fvec.Vec; import water.util.fp.Function; import water.util.fp.Functions; import static hex.AUC2.ThresholdCriterion.precision; import static hex.AUC2.ThresholdCriterion.recall; /** One-pass approximate AUC * * This algorithm can compute the AUC in 1-pass with good resolution. During * the pass, it builds an online histogram of the probabilities up to the * resolution (number of bins) asked-for. It also computes the true-positive * and false-positive counts for the histogramed thresholds. With these in * hand, we can compute the TPR (True Positive Rate) and the FPR for the given * thresholds; these define the (X,Y) coordinates of the AUC. */ public class AUC2 extends Iced { public final int _nBins; // Max number of bins; can be less if there are fewer points public final double[] _ths; // Thresholds public final double[] _tps; // True Positives public final double[] _fps; // False Positives public final double _p, _n; // Actual trues, falses public final double _auc, _gini; // Actual AUC value public final int _max_idx; // Threshold that maximizes the default criterion public static final ThresholdCriterion DEFAULT_CM = ThresholdCriterion.f1; // Default bins, good answers on a highly unbalanced sorted (and reverse // sorted) datasets public static final int NBINS = 400; /** Criteria for 2-class Confusion Matrices * * This is an Enum class, with an exec() function to compute the criteria * from the basic parts, and from an AUC2 at a given threshold index. */ public enum ThresholdCriterion { f1(false) { @Override double exec( double tp, double fp, double fn, double tn ) { final double prec = precision.exec(tp,fp,fn,tn); final double recl = tpr .exec(tp,fp,fn,tn); return 2. * (prec * recl) / (prec + recl); } }, f2(false) { @Override double exec( double tp, double fp, double fn, double tn ) { final double prec = precision.exec(tp,fp,fn,tn); final double recl = tpr .exec(tp,fp,fn,tn); return 5. * (prec * recl) / (4. * prec + recl); } }, f0point5(false) { @Override double exec( double tp, double fp, double fn, double tn ) { final double prec = precision.exec(tp,fp,fn,tn); final double recl = tpr .exec(tp,fp,fn,tn); return 1.25 * (prec * recl) / (.25 * prec + recl); } }, accuracy(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return (tn+tp)/(tp+fn+tn+fp); } }, precision(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return tp/(tp+fp); } }, recall(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return tp/(tp+fn); } }, specificity(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return tn/(tn+fp); } }, absolute_mcc(false) { @Override double exec( double tp, double fp, double fn, double tn ) { double mcc = (tp*tn - fp*fn); if (mcc == 0) return 0; mcc /= Math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn)); assert(Math.abs(mcc)<=1.) : tp + " " + fp + " " + fn + " " + tn; return Math.abs(mcc); } }, // minimize max-per-class-error by maximizing min-per-class-accuracy. // Report from max_criterion is the smallest correct rate for both classes. // The max min-error-rate is 1.0 minus that. min_per_class_accuracy(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return Math.min(tp/(tp+fn),tn/(tn+fp)); } }, mean_per_class_accuracy(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return 0.5*(tp/(tp+fn) + tn/(tn+fp)); } }, tns(true ) { @Override double exec( double tp, double fp, double fn, double tn ) { return tn; } }, fns(true ) { @Override double exec( double tp, double fp, double fn, double tn ) { return fn; } }, fps(true ) { @Override double exec( double tp, double fp, double fn, double tn ) { return fp; } }, tps(true ) { @Override double exec( double tp, double fp, double fn, double tn ) { return tp; } }, tnr(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return tn/(fp+tn); } }, fnr(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return fn/(fn+tp); } }, fpr(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return fp/(fp+tn); } }, tpr(false) { @Override double exec( double tp, double fp, double fn, double tn ) { return tp/(tp+fn); } }, ; public final boolean _isInt; // Integral-Valued data vs Real-Valued ThresholdCriterion(boolean isInt) { _isInt = isInt; } /** @param tp True Positives (predicted true, actual true ) * @param fp False Positives (predicted true, actual false) * @param fn False Negatives (predicted false, actual true ) * @param tn True Negatives (predicted false, actual false) * @return criteria */ abstract double exec( double tp, double fp, double fn, double tn ); public double exec( AUC2 auc, int idx ) { return exec(auc.tp(idx),auc.fp(idx),auc.fn(idx),auc.tn(idx)); } public double max_criterion( AUC2 auc ) { return exec(auc,max_criterion_idx(auc)); } /** Convert a criterion into a threshold index that maximizes the criterion * @return Threshold index that maximizes the criterion */ public int max_criterion_idx( AUC2 auc ) { double md = -Double.MAX_VALUE; int mx = -1; for( int i=0; i<auc._nBins; i++ ) { double d = exec(auc,i); if( d > md ) { md = d; mx = i; } } return mx; } public static final ThresholdCriterion[] VALUES = values(); } // public enum ThresholdCriterion public double threshold( int idx ) { return _ths[idx]; } public double tp( int idx ) { return _tps[idx]; } public double fp( int idx ) { return _fps[idx]; } public double tn( int idx ) { return _n-_fps[idx]; } public double fn( int idx ) { return _p-_tps[idx]; } /** @return maximum F1 */ public double maxF1() { return ThresholdCriterion.f1.max_criterion(this); } public Function<Integer, Double> forCriterion(final ThresholdCriterion tc) { return new Function<Integer, Double>() { public Double apply(Integer i) { return tc.exec(AUC2.this, i); } }; } /** Default bins, good answers on a highly unbalanced sorted (and reverse * sorted) datasets */ public AUC2( Vec probs, Vec actls ) { this(NBINS,probs,actls); } /** User-specified bin limits. Time taken is product of nBins and rows; * large nBins can be very slow. */ AUC2( int nBins, Vec probs, Vec actls ) { this(new AUC_Impl(nBins).doAll(probs,actls)._bldr); } public AUC2( AUCBuilder bldr ) { // Copy result arrays into base object, shrinking to match actual bins _nBins = bldr._n; assert _nBins >= 1 : "Must have >= 1 bins for AUC calculation, but got " + _nBins; _ths = Arrays.copyOf(bldr._ths,_nBins); _tps = Arrays.copyOf(bldr._tps,_nBins); _fps = Arrays.copyOf(bldr._fps,_nBins); // Reverse everybody; thresholds from 1 down to 0, easier to read for( int i=0; i<((_nBins)>>1); i++ ) { double tmp= _ths[i]; _ths[i] = _ths[_nBins-1-i]; _ths[_nBins-1-i] = tmp ; double tmpt = _tps[i]; _tps[i] = _tps[_nBins-1-i]; _tps[_nBins-1-i] = tmpt; double tmpf = _fps[i]; _fps[i] = _fps[_nBins-1-i]; _fps[_nBins-1-i] = tmpf; } // Rollup counts, so that computing the rates are easier. // The AUC is (TPR,FPR) as the thresholds roll about double p=0, n=0; for( int i=0; i<_nBins; i++ ) { p += _tps[i]; _tps[i] = p; n += _fps[i]; _fps[i] = n; } _p = p; _n = n; _auc = compute_auc(); _gini = 2*_auc-1; _max_idx = DEFAULT_CM.max_criterion_idx(this); } public double pr_auc() { checkRecallValidity(); return Functions.integrate(forCriterion(recall), forCriterion(precision), 0, _nBins-1); } // Checks that recall is monotonic function. // According to Leland, it should be; otherwise it's an error. void checkRecallValidity() { double x0 = recall.exec(this, 0); for (int i = 1; i < _nBins; i++) { double x1 = recall.exec(this, i); if (x0 >= x1) throw new H2OIllegalArgumentException(""+i, "recall", ""+x1 + "<" + x0); } } // Compute the Area Under the Curve, where the curve is defined by (TPR,FPR) // points. TPR and FPR are monotonically increasing from 0 to 1. private double compute_auc() { if (_fps[_nBins-1] == 0) return 1.0; //special case if (_tps[_nBins-1] == 0) return 0.0; //special case // All math is computed scaled by TP and FP. We'll descale once at the // end. Trapezoids from (tps[i-1],fps[i-1]) to (tps[i],fps[i]) double tp0 = 0, fp0 = 0; double area = 0; for( int i=0; i<_nBins; i++ ) { area += (_fps[i]-fp0)*(_tps[i]+tp0)/2.0; // Trapezoid tp0 = _tps[i]; fp0 = _fps[i]; } // Descale return area/_p/_n; } // Build a CM for a threshold index. - typed as doubles because of double observation weights public double[/*actual*/][/*predicted*/] buildCM( int idx ) { // \ predicted: 0 1 // actual 0: TN FP // 1: FN TP return new double[][]{{tn(idx),fp(idx)},{fn(idx),tp(idx)}}; } /** @return the default CM, or null for an empty AUC */ public double[/*actual*/][/*predicted*/] defaultCM( ) { return _max_idx == -1 ? null : buildCM(_max_idx); } /** @return the default threshold; threshold that maximizes the default criterion */ public double defaultThreshold( ) { return _max_idx == -1 ? 0.5 : _ths[_max_idx]; } /** @return the error of the default CM */ public double defaultErr( ) { return _max_idx == -1 ? Double.NaN : (fp(_max_idx)+fn(_max_idx))/(_p+_n); } // Compute an online histogram of the predicted probabilities, along with // true positive and false positive totals in each histogram bin. private static class AUC_Impl extends MRTask<AUC_Impl> { final int _nBins; AUCBuilder _bldr; AUC_Impl( int nBins ) { _nBins = nBins; } @Override public void map( Chunk ps, Chunk as ) { AUCBuilder bldr = _bldr = new AUCBuilder(_nBins); for( int row = 0; row < ps._len; row++ ) if( !ps.isNA(row) && !as.isNA(row) ) bldr.perRow(ps.atd(row),(int)as.at8(row),1); } @Override public void reduce( AUC_Impl auc ) { _bldr.reduce(auc._bldr); } } public static class AUCBuilder extends Iced { final int _nBins; int _n; // Current number of bins final double _ths[]; // Histogram bins, center final double _sqe[]; // Histogram bins, squared error final double _tps[]; // Histogram bins, true positives final double _fps[]; // Histogram bins, false positives // Merging this bin with the next gives the least increase in squared // error, or -1 if not known. Requires a linear scan to find. int _ssx; public AUCBuilder(int nBins) { _nBins = nBins; _ths = new double[nBins<<1]; // Threshold; also the mean for this bin _sqe = new double[nBins<<1]; // Squared error (variance) in this bin _tps = new double[nBins<<1]; // True positives _fps = new double[nBins<<1]; // False positives _ssx = -1; // Unknown best merge bin } public void perRow(double pred, int act, double w ) { // Insert the prediction into the set of histograms in sorted order, as // if its a new histogram bin with 1 count. assert !Double.isNaN(pred); assert act==0 || act==1; // Actual better be 0 or 1 int idx = Arrays.binarySearch(_ths,0,_n,pred); if( idx >= 0 ) { // Found already in histogram; merge results if( act==0 ) _fps[idx]+=w; else _tps[idx]+=w; // One more count; no change in squared error _ssx = -1; // Blows the known best merge return; } idx = -idx-1; // Get index to insert at // If already full bins, try to instantly merge into an existing bin if( _n > _nBins ) { // Need to merge to shrink things final int ssx = find_smallest(); double dssx = compute_delta_error(_ths[ssx+1],k(ssx+1),_ths[ssx],k(ssx)); // See if this point will fold into either the left or right bin // immediately. This is the desired fast-path. double d0 = compute_delta_error(pred,w,_ths[idx ],k(idx )); double d1 = compute_delta_error(_ths[idx+1],k(idx+1),pred,w); if( d0 < dssx || d1 < dssx ) { if( d1 < d0 ) idx++; else d0 = d1; // Pick correct bin double oldk = k(idx); if( act==0 ) _fps[idx]+=w; else _tps[idx]+=w; _ths[idx] = _ths[idx] + (pred-_ths[idx])/oldk; _sqe[idx] = _sqe[idx] + d0; assert ssx == find_smallest(); return; } } // Must insert this point as it's own threshold (which is not insertion // point), either because we have too few bins or because we cannot // instantly merge the new point into an existing bin. if( idx == _ssx ) _ssx = -1; // Smallest error becomes one of the splits else if( idx < _ssx ) _ssx++; // Smallest error will slide right 1 // Slide over to do the insert. Horrible slowness. System.arraycopy(_ths,idx,_ths,idx+1,_n-idx); System.arraycopy(_sqe,idx,_sqe,idx+1,_n-idx); System.arraycopy(_tps,idx,_tps,idx+1,_n-idx); System.arraycopy(_fps,idx,_fps,idx+1,_n-idx); // Insert into the histogram _ths[idx] = pred; // New histogram center _sqe[idx] = 0; // Only 1 point, so no squared error if( act==0 ) { _tps[idx]=0; _fps[idx]=w; } else { _tps[idx]=w; _fps[idx]=0; } _n++; if( _n > _nBins ) // Merge as needed back down to nBins mergeOneBin(); // Merge best pair of bins } public void reduce( AUCBuilder bldr ) { // Merge sort the 2 sorted lists into the double-sized arrays. The tail // half of the double-sized array is unused, but the front half is // probably a source. Merge into the back. //assert sorted(); //assert bldr.sorted(); int x= _n-1; int y=bldr._n-1; while( x+y+1 >= 0 ) { boolean self_is_larger = y < 0 || (x >= 0 && _ths[x] >= bldr._ths[y]); AUCBuilder b = self_is_larger ? this : bldr; int idx = self_is_larger ? x : y ; _ths[x+y+1] = b._ths[idx]; _sqe[x+y+1] = b._sqe[idx]; _tps[x+y+1] = b._tps[idx]; _fps[x+y+1] = b._fps[idx]; if( self_is_larger ) x--; else y--; } _n += bldr._n; //assert sorted(); // Merge elements with least squared-error increase until we get fewer // than _nBins and no duplicates. May require many merges. while( _n > _nBins || dups() ) mergeOneBin(); } private void mergeOneBin( ) { // Too many bins; must merge bins. Merge into bins with least total // squared error. Horrible slowness linear arraycopy. int ssx = find_smallest(); // Merge two bins. Classic bins merging by averaging the histogram // centers based on counts. double k0 = k(ssx); double k1 = k(ssx+1); _ths[ssx] = (_ths[ssx]*k0 + _ths[ssx+1]*k1) / (k0+k1); _sqe[ssx] = _sqe[ssx]+_sqe[ssx+1]+compute_delta_error(_ths[ssx+1],k1,_ths[ssx],k0); _tps[ssx] += _tps[ssx+1]; _fps[ssx] += _fps[ssx+1]; // Slide over to crush the removed bin at index (ssx+1) System.arraycopy(_ths,ssx+2,_ths,ssx+1,_n-ssx-2); System.arraycopy(_sqe,ssx+2,_sqe,ssx+1,_n-ssx-2); System.arraycopy(_tps,ssx+2,_tps,ssx+1,_n-ssx-2); System.arraycopy(_fps,ssx+2,_fps,ssx+1,_n-ssx-2); _n--; _ssx = -1; } // Find the pair of bins that when combined give the smallest increase in // squared error. Dups never increase squared error. // // I tried code for merging bins with keeping the bins balanced in size, // but this leads to bad errors if the probabilities are sorted. Also // tried the original: merge bins with the least distance between bin // centers. Same problem for sorted data. private int find_smallest() { if( _ssx == -1 ) return (_ssx = find_smallest_impl()); assert _ssx == find_smallest_impl(); return _ssx; } private int find_smallest_impl() { double minSQE = Double.MAX_VALUE; int minI = -1; int n = _n; for( int i=0; i<n-1; i++ ) { double derr = compute_delta_error(_ths[i+1],k(i+1),_ths[i],k(i)); if( derr == 0 ) return i; // Dup; no increase in SQE so return immediately double sqe = _sqe[i]+_sqe[i+1]+derr; if( sqe < minSQE ) { minI = i; minSQE = sqe; } } return minI; } private boolean dups() { int n = _n; for( int i=0; i<n-1; i++ ) { double derr = compute_delta_error(_ths[i+1],k(i+1),_ths[i],k(i)); if( derr == 0 ) { _ssx = i; return true; } } return false; } private double compute_delta_error( double ths1, double n1, double ths0, double n0 ) { // If thresholds vary by less than a float ULP, treat them as the same. // Some models only output predictions to within float accuracy (so a // variance here is junk), and also it's not statistically sane to have // a model which varies predictions by such a tiny change in thresholds. double delta = (float)ths1-(float)ths0; // Parallel equation drawn from: // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm return delta*delta*n0*n1 / (n0+n1); } private double k( int idx ) { return _tps[idx]+_fps[idx]; } //private boolean sorted() { // double t = _ths[0]; // for( int i=1; i<_n; i++ ) { // if( _ths[i] < t ) // return false; // t = _ths[i]; // } // return true; //} } // ========== // Given the probabilities of a 1, and the actuals (0/1) report the perfect // AUC found by sorting the entire dataset. Expensive, and only works for // small data (probably caps out at about 10M rows). public static double perfectAUC( Vec vprob, Vec vacts ) { if( vacts.min() < 0 || vacts.max() > 1 || !vacts.isInt() ) throw new IllegalArgumentException("Actuals are either 0 or 1"); if( vprob.min() < 0 || vprob.max() > 1 ) throw new IllegalArgumentException("Probabilities are between 0 and 1"); // Horrible data replication into array of structs, to sort. Pair[] ps = new Pair[(int)vprob.length()]; Vec.Reader rprob = vprob.new Reader(); Vec.Reader racts = vacts.new Reader(); for( int i=0; i<ps.length; i++ ) ps[i] = new Pair(rprob.at(i),(byte)racts.at8(i)); return perfectAUC(ps); } public static double perfectAUC( double ds[], double[] acts ) { Pair[] ps = new Pair[ds.length]; for( int i=0; i<ps.length; i++ ) ps[i] = new Pair(ds[i],(byte)acts[i]); return perfectAUC(ps); } private static double perfectAUC( Pair[] ps ) { // Sort by probs, then actuals - so tied probs have the 0 actuals before // the 1 actuals. Sort probs from largest to smallest - so both the True // and False Positives are zero to start. Arrays.sort(ps,new java.util.Comparator<Pair>() { @Override public int compare( Pair a, Pair b ) { return a._prob<b._prob ? 1 : (a._prob==b._prob ? (b._act-a._act) : -1); } }); // Compute Area Under Curve. // All math is computed scaled by TP and FP. We'll descale once at the // end. Trapezoids from (tps[i-1],fps[i-1]) to (tps[i],fps[i]) int tp0=0, fp0=0, tp1=0, fp1=0; double prob = 1.0; double area = 0; for( Pair p : ps ) { if( p._prob!=prob ) { // Tied probabilities: build a diagonal line area += (fp1-fp0)*(tp1+tp0)/2.0; // Trapezoid tp0 = tp1; fp0 = fp1; prob = p._prob; } if( p._act==1 ) tp1++; else fp1++; } area += (double)tp0*(fp1-fp0); // Trapezoid: Rectangle + area += (double)(tp1-tp0)*(fp1-fp0)/2.0; // Right Triangle // Descale return area/tp1/fp1; } private static class Pair { final double _prob; final byte _act; Pair( double prob, byte act ) { _prob = prob; _act = act; } } }