/*******************************************************************************
* Copyright (c) 2010 Haifeng Li
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
package smile.validation;
import smile.sort.QuickSort;
/**
* The area under the curve (AUC). When using normalized units, the area under
* the curve is equal to the probability that a classifier will rank a
* randomly chosen positive instance higher than a randomly chosen negative
* one (assuming 'positive' ranks higher than 'negative').
* <p>
* In statistics, a receiver operating characteristic (ROC), or ROC curve,
* is a graphical plot that illustrates the performance of a binary classifier
* system as its discrimination threshold is varied. The curve is created by
* plotting the true positive rate (TPR) against the false positive rate (FPR)
* at various threshold settings.
* <p>
* AUC is quite noisy as a classification measure and has some other
* significant problems in model comparison.
* <p>
* We calculate AUC based on the Mann-Whitney U test
* (https://en.wikipedia.org/wiki/Mann-Whitney_U_test).
*
* @author Haifeng Li
*/
public class AUC {
    /**
     * Default constructor. All functionality is exposed through the static
     * {@link #measure(int[], double[])} method; the constructor stays public
     * so that existing code instantiating this class keeps compiling.
     */
    public AUC() {
    }

    /**
     * Calculates AUC for a binary classifier.
     * <p>
     * The samples are ordered by predicted probability and given 1-based
     * ranks, where samples with equal probability share the average of the
     * ranks they span. With {@code R} the sum of the ranks of the positive
     * samples, the result is
     * {@code (R - pos * (pos + 1) / 2) / (pos * neg)}.
     * <p>
     * The input arrays are not modified; sorting is done on copies.
     * If the labels contain only one class (or the arrays are empty), the
     * denominator {@code pos * neg} is zero and the result is not a valid
     * AUC (the division typically yields {@code NaN}).
     *
     * @param truth The sample labels. Every element must be 0 (negative)
     *              or 1 (positive).
     * @param probability The posterior probability of positive class.
     * @return AUC
     * @throws IllegalArgumentException if the two arrays differ in length,
     *         or if a label is neither 0 nor 1.
     */
    public static double measure(int[] truth, double[] probability) {
        if (truth.length != probability.length) {
            throw new IllegalArgumentException(String.format("The vector sizes don't match: %d != %d.", truth.length, probability.length));
        }

        // For large sample sizes pos * neg may overflow an int,
        // so the class counts are accumulated as doubles.
        double pos = 0;
        double neg = 0;
        for (int i = 0; i < truth.length; i++) {
            if (truth[i] == 0) {
                neg++;
            } else if (truth[i] == 1) {
                pos++;
            } else {
                throw new IllegalArgumentException("AUC is only for binary classification. Invalid label: " + truth[i]);
            }
        }

        // Work on copies so the caller's arrays keep their order.
        int[] label = truth.clone();
        double[] prediction = probability.clone();
        // QuickSort.sort is relied upon to order prediction ascending and to
        // apply the same permutation to label; the ranking below assumes it.
        QuickSort.sort(prediction, label);

        double[] rank = new double[label.length];
        for (int i = 0; i < prediction.length; i++) {
            if (i == prediction.length - 1 || prediction[i] != prediction[i + 1]) {
                // No tie with the next element: the rank is the 1-based position.
                rank[i] = i + 1;
            } else {
                // Find the end (exclusive) of the run of values tied with prediction[i].
                int j = i + 1;
                while (j < prediction.length && prediction[j] == prediction[i]) {
                    j++;
                }

                // The run occupies ranks i+1 .. j; every member gets their mean.
                // The sum is taken in double arithmetic because i + 1 + j can
                // exceed Integer.MAX_VALUE for very large arrays.
                double r = (i + 1.0 + j) / 2.0;
                for (int k = i; k < j; k++) {
                    rank[k] = r;
                }

                // Resume after the run (the loop increment moves i to j).
                i = j - 1;
            }
        }

        // Sum of the ranks of the positive samples.
        double auc = 0.0;
        for (int i = 0; i < label.length; i++) {
            if (label[i] == 1) {
                auc += rank[i];
            }
        }

        // Mann-Whitney U statistic of the positive class, normalized by pos * neg.
        auc = (auc - (pos * (pos + 1) / 2.0)) / (pos * neg);
        return auc;
    }
}