/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.eval;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ml.shifu.shifu.container.PerformanceObject;
/**
* Class for computing area under curve.
*
* @author xiaobzheng (zheng.xiaobin.roubao@gmail.com)
*/
public final class AreaUnderCurve {
private AreaUnderCurve() {
}
private static final Logger LOG = LoggerFactory.getLogger(AreaUnderCurve.class);
/**
* Compute the area under the line connecting the two input points by the trapezoidal rule. The point
* is stored as two double value which refer to x-coordinate and y-coordinate respectively.
*
* <p>
* Note: x2 is considered to be no less than x1, so that (x2 - x1) >= 0 and the return value is always a nonnegative
* </p>
*
* @param x1
* x-coordinate of first point.
* @param y1
* y-coordinate of first point.
* @param x2
* x-coordinate of second point.
* @param y2
* y-coordinate of second point.
* @return trapezoid area.
*/
public static double trapezoid(double x1, double y1, double x2, double y2) {
return (y2 + y1) * (x2 - x1) / 2.0;
}
/**
* Calculate area under ROC curve based on the PerformanceObject List.
*
* @param roc
* PerformanceObject List contains ROC curve data.
* @return area under ROC. Return 0 if input list is null or the size of list is less than 2.
*/
public static double ofRoc(List<PerformanceObject> roc) {
return calculateArea(roc, Performances.fpr(), Performances.recall());
}
/**
* Calculate area under Weighted ROC curve based on the PerformanceObject List.
*
* @param weightedRoc
* PerformanceObject List contains Weighted ROC curve data.
* @return area under Weighted ROC. Return 0 if input list is null or the size of list is less than 2.
*/
public static double ofWeightedRoc(List<PerformanceObject> weightedRoc) {
return calculateArea(weightedRoc, Performances.weightedFpr(), Performances.weightedRecall());
}
/**
* Calculate area under PR curve based on the PerformanceObject List.
*
* @param pr
* PerformanceObject List contains PR curve data.
* @return area under PR. Return 0 if input list is null or the size of list is less than 2.
*/
public static double ofPr(List<PerformanceObject> pr) {
return calculateArea(pr, Performances.recall(), Performances.precision());
}
/**
* Calculate area under Weighted PR curve based on the PerformanceObject List.
*
* @param weightedPr
* PerformanceObject List contains Weighted PR curve data.
* @return area under Weighted PR. Return 0 if input list is null or the size of list is less than 2.
*/
public static double ofWeightedPr(List<PerformanceObject> weightedPr) {
return calculateArea(weightedPr, Performances.weightedRecall(), Performances.weightedPrecision());
}
/**
* Calculate curve area by trapezoidal rule based on the given PerformanceObject List and extractor.
*
* @param perform
* PerformanceObject List contains curve data.
* @param xExtractor
* PerformanceExtractor instance used extract x of the point from PerformanceObject.
* @param yExtractor
* PerformanceExtractor instance used extract y of the point from PerformanceObject.
* @return the area under the curve. Return 0 if input list is null or the size of list is less than 2.
* @throws IllegalArgumentException
* if the input xExtractor or yExtractor is null.
*/
public static double calculateArea(List<PerformanceObject> perform, PerformanceExtractor xExtractor,
PerformanceExtractor yExtractor) {
if(perform == null) {
LOG.warn("Input PerformanceObject List is null! Maybe you should check the input.");
return 0;
}
if(perform.size() < 2) {
LOG.warn("We need at least 2 point to calculate area! Maybe you should check the input.");
return 0;
}
if(xExtractor == null || yExtractor == null) {
throw new IllegalArgumentException("The xExtractor and yExtractor can't be null!");
}
// accumulate the trapezoid area of every successive two points in the curve.
Iterator<PerformanceObject> iterator = perform.iterator();
double sum = 0.0;
PerformanceObject per = iterator.next();
double x1 = xExtractor.extract(per);
double y1 = yExtractor.extract(per);
double x2;
double y2;
while(iterator.hasNext()) {
per = iterator.next();
x2 = xExtractor.extract(per);
y2 = yExtractor.extract(per);
sum += trapezoid(x1, y1, x2, y2);
x1 = x2;
y1 = y2;
}
return sum;
}
}