package hex;
import java.util.Arrays;
import org.junit.*;
import water.*;
import water.fvec.Frame;
import water.util.ArrayUtils;
/**
 * Tests for {@link AUC2}.
 *
 * <p>Checks {@code AUC2.perfectAUC} and the AUC computed by an {@code AUC2}
 * instance against small hand-computed examples and against reference values
 * that the inline comments attribute to R's ROCR package.
 */
public class AUCTest extends TestUtil {
  /** Reference AUC for the 17-row sorted-probability dataset used in {@link #testAUC0()}. */
  private static final double SMALL_DATASET_AUC = 0.636363636363;

  /** Runs once before the tests in this class; calls {@code stall_till_cloudsize(1)}. */
  @BeforeClass public static void stall() { stall_till_cloudsize(1); }

  /**
   * Exercises AUC computation on tiny in-memory arrays (including tied
   * probabilities), on a 17-row dataset before and after shuffling rows, and
   * finally on the {@code smalldata/junit/auc.csv.gz} test file.
   */
  @Test public void testAUC0() {
    double auc0 = AUC2.perfectAUC(new double[]{0,0.5,0.5,1}, new double[]{0,0,1,1});
    Assert.assertEquals(0.875,auc0,1e-7);

    // Flip the tied actuals
    double auc1 = AUC2.perfectAUC(new double[]{0,0.5,0.5,1}, new double[]{0,1,0,1});
    Assert.assertEquals(0.875,auc1,1e-7);

    // Area is 10/12 (TPS=4, FPS=3, so area is 4x3 or 12 units; 10 are under).
    double auc2 = AUC2.perfectAUC(new double[]{0.1,0.2,0.3,0.4,0.5,0.6,0.7}, new double[]{0,0,1,1,0,1,1});
    Assert.assertEquals(0.8333333,auc2,1e-7);

    // Sorted probabilities.  At threshold 1e-6 flips from false to true, on
    // average.  However, there are a lot of random choices at 1e-3.
    double[] probs = new double[]{1e-8,1e-7,1e-6,1e-5,1e-4,1e-3,1e-3,1e-3,1e-3,1e-3,1e-3,1e-3,1e-3,1e-3,1e-3,1e-2,1e-1};
    double[] actls = new double[]{   0,   0,   1,   1,   1,   1,   0,   1,   0,   1,   0,   1,   0,   1,   1,   1,   1};

    // Diagnostic output only; none of the printed counts feed an assertion.
    printThresholdCounts(probs, actls);

    // The AUC for this dataset, according to R's ROCR package, is 0.6363636363
    Assert.assertEquals(SMALL_DATASET_AUC, doAUC(probs,actls), 1e-5);
    Assert.assertEquals(SMALL_DATASET_AUC, AUC2.perfectAUC(probs,actls), 1e-7);

    // Shuffle, check again: the result must not depend on row order.
    swap(0, 5, probs, actls);
    swap(1, 6, probs, actls);
    swap(7, 15, probs, actls);
    Assert.assertEquals(SMALL_DATASET_AUC, doAUC(probs,actls), 1e-5);
    Assert.assertEquals(SMALL_DATASET_AUC, AUC2.perfectAUC(probs,actls), 1e-7);

    // Now from a large test file
    double ROCR_auc = 0.7244389;
    Frame fr = parse_test_file("smalldata/junit/auc.csv.gz");
    try {
      // Slow; used to confirm the accuracy as we increase bin counts
      //for( int i=10; i<1000; i+=10 ) {
      //  AUC2 auc = new AUC2(i,fr.vec("V1"),fr.vec("V2"));
      //  System.out.println("bins="+i+", aucERR="+Math.abs(auc._auc-ROCR_auc)/ROCR_auc);
      //  Assert.assertEquals(fr.numRows(), auc._p+auc._n);
      //}

      double aucp = AUC2.perfectAUC(fr.vec("V1"), fr.vec("V2"));
      Assert.assertEquals(ROCR_auc, aucp, 1e-4);
      AUC2 auc = new AUC2(fr.vec("V1"), fr.vec("V2"));
      Assert.assertEquals(ROCR_auc, auc._auc, 1e-4);
      Assert.assertEquals(1.0, AUC2.ThresholdCriterion.precision.max_criterion(auc), 1e-4);

      double ROCR_max_abs_mcc = 0.4553512;
      Assert.assertEquals(ROCR_max_abs_mcc, AUC2.ThresholdCriterion.absolute_mcc.max_criterion(auc), 1e-3);

      // NOTE(review): ROCR's "f" measure appears to be balanced at alpha=0.5,
      // not alpha=0 -- confirm against the ROCR documentation.
      double ROCR_f1 = 0.9920445; // same as ROCR "f" with alpha=0, or alternative beta=1
      Assert.assertEquals(ROCR_f1, AUC2.ThresholdCriterion.f1.max_criterion(auc), 1e-4);
    } finally {
      // Always release the parsed frame, even when an assertion above fails.
      if( fr != null ) fr.remove();
    }
  }

  /**
   * Prints, for a fixed list of thresholds, the true/false positive and
   * negative counts of the given predictions, followed by the
   * {tp/P, fp/N} pair for each threshold.
   *
   * @param probs predicted probabilities, one per row
   * @param actls actual labels, one per row; a value of 0.0 is counted as negative
   */
  private static void printThresholdCounts(double[] probs, double[] actls) {
    // Positives & Negatives
    int P = 0;
    for( double a : actls ) P += (int)a;
    int N = actls.length - P;
    System.out.println("P="+P+", N="+N);

    // Compute TP & FP for all thresholds
    double[] thresh = new double[]{1e-1,1e-2,1e-3+1e-9,1e-3,1e-3-1e-9,1e-4,1e-5,1e-6,1e-7,1e-8,0};
    int[] tp = new int[thresh.length], fp = new int[thresh.length];
    int[] tn = new int[thresh.length], fn = new int[thresh.length];
    for( int i=0; i<probs.length; i++ ) {
      boolean negative = actls[i]==0.0;
      for( int t=0; t<thresh.length; t++ ) {
        if( probs[i] >= thresh[t] ) {   // Predicted positive at this threshold
          if( negative ) fp[t]++;       // False positive
          else           tp[t]++;       // True positive
        } else {                        // Predicted negative at this threshold
          if( negative ) tn[t]++;       // True negative
          else           fn[t]++;       // False negative
        }
      }
    }
    System.out.println(Arrays.toString(tp));
    System.out.println(Arrays.toString(fp));
    System.out.println(Arrays.toString(fn));
    System.out.println(Arrays.toString(tn));
    for( int i=0; i<tp.length; i++ ) System.out.print("{"+((double)tp[i]/P)+","+((double)fp[i]/N)+"} ");
    System.out.println();
  }

  /**
   * Builds a two-column frame ("probs", "actls") from the arrays, constructs
   * an {@link AUC2} over it, prints per-bin diagnostics and returns the AUC.
   *
   * @param probs predicted probabilities, one per row
   * @param actls actual labels, one per row (same length as {@code probs})
   * @return the {@code _auc} field of the constructed {@code AUC2}
   */
  private static double doAUC(double[] probs, double[] actls) {
    double[][] rows = new double[probs.length][];
    for( int i=0; i<probs.length; i++ )
      rows[i] = new double[]{probs[i],actls[i]};
    Frame fr = ArrayUtils.frame(new String[]{"probs", "actls"}, rows);
    AUC2 auc;
    try {
      auc = new AUC2(fr.vec("probs"),fr.vec("actls"));
    } finally {
      // Release the temporary frame even if AUC2 construction throws.
      fr.remove();
    }
    for( int i=0; i<auc._nBins; i++ ) System.out.print("{"+((double)auc._tps[i]/auc._p)+","+((double)auc._fps[i]/auc._n)+"} ");
    System.out.println();
    for( int i=0; i<auc._nBins; i++ ) System.out.print(AUC2.ThresholdCriterion.min_per_class_accuracy.exec(auc,i)+" ");
    System.out.println();
    return auc._auc;
  }

  /**
   * Swaps rows {@code x} and {@code y} in both arrays in place, keeping each
   * probability paired with its actual label.
   */
  private static void swap(int x, int y, double[] probs, double[] actls) {
    double tmp0 = probs[x]; probs[x] = probs[y]; probs[y] = tmp0;
    double tmp1 = actls[x]; actls[x] = actls[y]; actls[y] = tmp1;
  }
}