package quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions;
import quickml.supervised.crossValidation.PredictionMapResult;
import quickml.supervised.crossValidation.PredictionMapResults;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction;
import java.io.Serializable;
import java.util.*;
/**
* AUCCrossValLoss calculates the ROC area over the curve to determine loss.
* Created by Chris on 5/5/2014.
*/
public class WeightedAUCCrossValLossFunction extends ClassifierLossFunction {
private final Serializable positiveClassification;
public WeightedAUCCrossValLossFunction(Serializable positiveClassification) {
this.positiveClassification = positiveClassification;
}
@Override
public Double getLoss(PredictionMapResults results) {
List<AUCData> aucDataList = getAucDataList(results);
//order by probabilityOfPositiveClassification ascending
Collections.sort(aucDataList);
ArrayList<AUCPoint> aucPoints = getAUCPointsFromData(aucDataList);
return getAUCLoss(aucPoints);
}
@Override
public String getName() {
return "WEIGHTED_AUC";
}
private List<AUCData> getAucDataList(PredictionMapResults results) {
ensureBinaryClassifications(results);
List<AUCData> aucDataList = new ArrayList<>();
for (PredictionMapResult result : results) {
double probabilityOfPositiveClassification = result.getPrediction().get(positiveClassification);
aucDataList.add(new AUCData(result.getLabel(), result.getWeight(), probabilityOfPositiveClassification));
}
return aucDataList;
}
private void ensureBinaryClassifications(PredictionMapResults results) {
Set<Serializable> classifications = new HashSet<>();
for (PredictionMapResult result : results) {
classifications.add(result.getLabel());
if (classifications.size() > 2) {
throw new RuntimeException("AUCCrossValLoss only supports binary classifications");
}
}
}
public ArrayList<AUCPoint> getAUCPointsFromData(List<AUCData> aucDataList) {
double truePositives = 0;
double trueNegatives = 0;
double falsePositives = 0;
double falseNegatives = 0;
ArrayList<AUCPoint> aucPoints = new ArrayList<>();
double thresholdForPositiveClassification = 0.0;
//start at upper right of ROC CURVE where everything is a positive
for (AUCData aucData : aucDataList) {
if (aucData.getClassification().equals(positiveClassification)) {
truePositives += aucData.getWeight();
} else {
falsePositives += aucData.getWeight();
}
}
//add 1,1 since we won't get it if we always predict 0.0
aucPoints.add(getAUCPoint(truePositives, falsePositives, trueNegatives, falseNegatives));
//iterate through each data point updating all points that are changed by the threshold
int startIndex = 0;
double probabilityOfNext = aucDataList.get(0).getProbabilityOfPositiveClassification();
while (probabilityOfNext <=0.0 && startIndex<aucDataList.size()) {
AUCData aucData = aucDataList.get(startIndex);
if (aucData.getClassification().equals(positiveClassification)) {
//add a false negative
falseNegatives += aucData.getWeight();
//remove true positive from previous threshold
truePositives -= aucData.getWeight();
} else {//we are negative and guessing negative
//add a true negative
trueNegatives += aucData.getWeight();
//remove a false positive from previous threshold
falsePositives -= aucData.getWeight();
}
startIndex++;
probabilityOfNext = aucData.getProbabilityOfPositiveClassification();
}
//now compute the non endpoint ROC curve points
for (int i = startIndex; i< aucDataList.size(); i++) {
//each computed probability of positive classification is used as a threshold (in ascending order
//which maps to the the upper right of the ROC curve.
// At each threshold, we know that at most one data point changed to be classified
// as a negative (and thus know the complete count of TPs FPs, TN, FN at that ROC point
//note, we make the threshold inclusive, in the sense that points are labeled positives if they are
//less than the threshold
AUCData aucData = aucDataList.get(i);
double probability = aucData.getProbabilityOfPositiveClassification();
//no need to double count
if (thresholdForPositiveClassification != probability && probability!=0.0) {
aucPoints.add(getAUCPoint(truePositives, falsePositives, trueNegatives, falseNegatives));
thresholdForPositiveClassification = probability;
}
//point is a positive but with the new threshold, we predict it is negative
if (aucData.getClassification().equals(positiveClassification)) {
//add a false negative
falseNegatives += aucData.getWeight();
//remove true positive from previous threshold
truePositives -= aucData.getWeight();
} else {//we are negative and guessing negative
//add a true negative
trueNegatives += aucData.getWeight();
//remove a false positive from previous threshold
falsePositives -= aucData.getWeight();
}
}
//adds last non (0, 0) point in roc space.
if (truePositives !=0 && falsePositives !=0) {
aucPoints.add(getAUCPoint(truePositives, falsePositives, trueNegatives, falseNegatives));
}
// (1,1)
aucPoints.add(getAUCPoint(0, 0, trueNegatives, falseNegatives));
return aucPoints;
}
public AUCPoint getAUCPoint(double truePositives, double falsePositives, double trueNegatives, double falseNegatives) {
double truePositiveRate = (truePositives + falseNegatives == 0) ? 0 : (truePositives / (truePositives + falseNegatives));
double falsePositiveRate = (falsePositives + trueNegatives == 0) ? 0 : (falsePositives / (falsePositives + trueNegatives));
return new AUCPoint(falsePositiveRate, truePositiveRate);
}
public double getAUCLoss(ArrayList<AUCPoint> aucPoints) {
Collections.sort(aucPoints);
double sumXY = 0.0;
//Area over curve OR AUCLoss = (2 - sum((x1-x0)(y1+y0)))/2
for (int i = 1; i < aucPoints.size(); i++) {
AUCPoint aucPoint1 = aucPoints.get(i);
AUCPoint aucPoint0 = aucPoints.get(i - 1);
sumXY += ((aucPoint1.getFalsePositiveRate() - aucPoint0.getFalsePositiveRate()) * (aucPoint1.getTruePositiveRate() + aucPoint0.getTruePositiveRate()));
}
return (2.0 - sumXY) / 2.0;
}
public static class AUCPoint implements Comparable<AUCPoint> {
private final double truePositiveRate;
private final double falsePositiveRate;
public AUCPoint(double falsePositiveRate, double truePositiveRate) {
this.truePositiveRate = truePositiveRate;
this.falsePositiveRate = falsePositiveRate;
}
public double getFalsePositiveRate() {
return falsePositiveRate;
}
public double getTruePositiveRate() {
return truePositiveRate;
}
@Override
public int compareTo(AUCPoint o) {
//order by false positive rate ascending, true positive rate ascending
if (falsePositiveRate > o.falsePositiveRate) {
return 1;
} else if (falsePositiveRate < o.falsePositiveRate) {
return -1;
} else {
return Double.compare(truePositiveRate, o.truePositiveRate);
}
}
}
public static class AUCData implements Comparable<AUCData> {
private final Serializable classification;
private final double weight;
private final double probabilityOfPositiveClassification;
public AUCData(Serializable classification, double weight, double probabilityOfPositiveClassification) {
this.classification = classification;
this.weight = weight;
this.probabilityOfPositiveClassification = probabilityOfPositiveClassification;
}
public Serializable getClassification() {
return classification;
}
public double getWeight() {
return weight;
}
public double getProbabilityOfPositiveClassification() {
return probabilityOfPositiveClassification;
}
@Override
public int compareTo(AUCData o) {
return Double.compare(probabilityOfPositiveClassification, o.probabilityOfPositiveClassification);
}
}
}