/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import ml.shifu.shifu.container.ConfusionMatrixObject;
import ml.shifu.shifu.container.PerformanceObject;
import ml.shifu.shifu.container.obj.EvalConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.PerformanceResult;
import ml.shifu.shifu.core.eval.AreaUnderCurve;
import ml.shifu.shifu.exception.ShifuErrorCode;
import ml.shifu.shifu.exception.ShifuException;
import ml.shifu.shifu.fs.PathFinder;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.util.Constants;
import ml.shifu.shifu.util.JSONUtils;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Evaluates model performance from the eval confusion-matrix file.
 * <p>
 * Each line of that file is one confusion matrix (TP, FP, FN, TN, their weighted counterparts, and a score). The
 * rows are bucketed into ROC (false positive rate), PR (catch rate / recall) and gains (action rate) curves, plus
 * weighted variants. The area under the ROC and PR curves is then computed. Results are logged and written as JSON
 * to the eval performance path.
 */
public class PerformanceEvaluator {

    private static final Logger log = LoggerFactory.getLogger(PerformanceEvaluator.class);

    private final ModelConfig modelConfig;

    private final EvalConfig evalConfig;

    /**
     * @param modelConfig model configuration, used to resolve file paths
     * @param evalConfig  evaluation set configuration (data source, bucket number, weight column)
     */
    public PerformanceEvaluator(ModelConfig modelConfig, EvalConfig evalConfig) {
        this.modelConfig = modelConfig;
        this.evalConfig = evalConfig;
    }

    /**
     * Loads the whole eval confusion-matrix file into memory, then buckets it and writes the performance result.
     *
     * @throws IOException    if the matrix file cannot be read or the result cannot be written
     * @throws ShifuException with {@link ShifuErrorCode#ERROR_EVALCONFMTR} if the matrix file contains no lines
     */
    public void review() throws IOException {
        PathFinder pathFinder = new PathFinder(modelConfig);
        log.info("Loading confusion matrix in {}",
                pathFinder.getEvalMatrixPath(evalConfig, evalConfig.getDataSet().getSource()));

        List<ConfusionMatrixObject> matrixList = new ArrayList<ConfusionMatrixObject>();
        BufferedReader reader = null;
        try {
            reader = ShifuFileUtils.getReader(
                    pathFinder.getEvalMatrixPath(evalConfig, evalConfig.getDataSet().getSource()),
                    evalConfig.getDataSet().getSource());
            String line;
            while((line = reader.readLine()) != null) {
                matrixList.add(parseConfusionMatrix(line));
            }
        } finally {
            // always release the reader, including when the file is empty or a line fails to parse
            IOUtils.closeQuietly(reader);
        }

        if(matrixList.isEmpty()) {
            log.info("No result read, please check EvalConfusionMatrix file");
            throw new ShifuException(ShifuErrorCode.ERROR_EVALCONFMTR);
        }
        review(matrixList, matrixList.size());
    }

    /**
     * Streams the eval confusion-matrix file line by line (without loading it into memory), then buckets it and
     * writes the performance result.
     *
     * @param records number of rows in the confusion-matrix file (presumably known by the caller); it is the
     *                denominator of the unweighted gains (action rate) curve
     * @throws IOException    if the matrix file cannot be read or the result cannot be written
     * @throws ShifuException with {@link ShifuErrorCode#ERROR_EVALCONFMTR} if {@code records} is 0
     */
    public void review(long records) throws IOException {
        if(0 == records) {
            log.info("No result read, please check EvalConfusionMatrix file");
            throw new ShifuException(ShifuErrorCode.ERROR_EVALCONFMTR);
        }
        PathFinder pathFinder = new PathFinder(modelConfig);
        log.info("Loading confusion matrix in {}",
                pathFinder.getEvalMatrixPath(evalConfig, evalConfig.getDataSet().getSource()));
        BufferedReader reader = null;
        try {
            reader = ShifuFileUtils.getReader(
                    pathFinder.getEvalMatrixPath(evalConfig, evalConfig.getDataSet().getSource()),
                    evalConfig.getDataSet().getSource());
            review(new CMOIterable(reader), records);
        } finally {
            IOUtils.closeQuietly(reader);
        }
    }

    /**
     * Parses one '|'-separated confusion-matrix line.
     * <p>
     * The field order is: tp, fp, fn, tn, weightedTp, weightedFp, weightedFn, weightedTn, score.
     */
    private static ConfusionMatrixObject parseConfusionMatrix(String line) {
        String[] raw = line.split("\\|");
        ConfusionMatrixObject matrix = new ConfusionMatrixObject();
        matrix.setTp(Double.parseDouble(raw[0]));
        matrix.setFp(Double.parseDouble(raw[1]));
        matrix.setFn(Double.parseDouble(raw[2]));
        matrix.setTn(Double.parseDouble(raw[3]));
        matrix.setWeightedTp(Double.parseDouble(raw[4]));
        matrix.setWeightedFp(Double.parseDouble(raw[5]));
        matrix.setWeightedFn(Double.parseDouble(raw[6]));
        matrix.setWeightedTn(Double.parseDouble(raw[7]));
        matrix.setScore(Double.parseDouble(raw[8]));
        return matrix;
    }

    /**
     * Lazy {@link Iterable} view that parses one {@link ConfusionMatrixObject} per line of a reader.
     * <p>
     * Every iterator reads from the same underlying reader, so the lines can be consumed only once. The caller
     * owns the reader and is responsible for closing it.
     */
    private static class CMOIterable implements Iterable<ConfusionMatrixObject> {

        private final BufferedReader reader;

        /**
         * @throws NullPointerException if {@code reader} is null
         */
        public CMOIterable(BufferedReader reader) {
            if(reader == null) {
                throw new NullPointerException("reader is null");
            }
            this.reader = reader;
        }

        @Override
        public Iterator<ConfusionMatrixObject> iterator() {
            return new Iterator<ConfusionMatrixObject>() {

                /** Line read ahead by hasNext() but not yet returned by next(); null when nothing is buffered. */
                private String nextLine;

                @Override
                public boolean hasNext() {
                    // read only when nothing is buffered, so repeated hasNext() calls do not skip lines
                    if(nextLine == null) {
                        try {
                            nextLine = reader.readLine();
                        } catch (IOException e) {
                            throw new RuntimeException(e);
                        }
                    }
                    return nextLine != null;
                }

                @Override
                public ConfusionMatrixObject next() {
                    if(!hasNext()) {
                        throw new NoSuchElementException();
                    }
                    String line = nextLine;
                    nextLine = null;
                    return parseConfusionMatrix(line);
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    }

    /**
     * Buckets the given confusion matrices, then writes the resulting {@link PerformanceResult} as JSON to the
     * eval performance path.
     *
     * @param matrixList confusion matrices, in file order
     * @param records    number of rows; the denominator of the unweighted gains curve
     * @throws IOException if the result file cannot be opened, written or closed
     */
    public void review(Iterable<ConfusionMatrixObject> matrixList, long records) throws IOException {
        PathFinder pathFinder = new PathFinder(modelConfig);
        // bucketing
        PerformanceResult result = bucketing(matrixList, records, evalConfig.getPerformanceBucketNum(),
                evalConfig.getDataSet().getWeightColumnName() != null);

        Writer writer = null;
        try {
            writer = ShifuFileUtils.getWriter(
                    pathFinder.getEvalPerformancePath(evalConfig, evalConfig.getDataSet().getSource()),
                    evalConfig.getDataSet().getSource());
            JSONUtils.writeValue(writer, result);
            // close explicitly on success so that a failure while flushing buffered output reaches the caller
            writer.close();
            writer = null;
        } finally {
            // non-null only when an exception interrupted the write; release the handle without masking it
            IOUtils.closeQuietly(writer);
        }
    }

    /**
     * Derives the rate metrics (action rate, recall, precision, FPR, lift) and their weighted counterparts from
     * one confusion matrix.
     * <p>
     * Zero denominators are not guarded; because every operand is a double, such metrics come out as NaN or
     * Infinity.
     */
    static PerformanceObject setPerformanceObject(ConfusionMatrixObject confMatObject) {
        PerformanceObject po = new PerformanceObject();
        po.binLowestScore = confMatObject.getScore();
        po.tp = confMatObject.getTp();
        po.tn = confMatObject.getTn();
        po.fp = confMatObject.getFp();
        po.fn = confMatObject.getFn();
        po.weightedTp = confMatObject.getWeightedTp();
        po.weightedTn = confMatObject.getWeightedTn();
        po.weightedFp = confMatObject.getWeightedFp();
        po.weightedFn = confMatObject.getWeightedFn();
        // Action Rate, TP + FP / Total;
        po.actionRate = (confMatObject.getTp() + confMatObject.getFp()) / confMatObject.getTotal();
        po.weightedActionRate = (confMatObject.getWeightedTp() + confMatObject.getWeightedFp())
                / confMatObject.getWeightedTotal();
        // recall = TP / (TP+FN)
        po.recall = confMatObject.getTp() / (confMatObject.getTp() + confMatObject.getFn());
        po.weightedRecall = confMatObject.getWeightedTp()
                / (confMatObject.getWeightedTp() + confMatObject.getWeightedFn());
        // precision = TP / (TP+FP)
        po.precision = confMatObject.getTp() / (confMatObject.getTp() + confMatObject.getFp());
        po.weightedPrecision = confMatObject.getWeightedTp()
                / (confMatObject.getWeightedTp() + confMatObject.getWeightedFp());
        // FPR, False Positive Rate (fp/(fp+tn))
        po.fpr = confMatObject.getFp() / (confMatObject.getFp() + confMatObject.getTn());
        po.weightedFpr = confMatObject.getWeightedFp()
                / (confMatObject.getWeightedFp() + confMatObject.getWeightedTn());
        // Lift tp / (number_action * (number_postive / all_unit))
        po.liftUnit = confMatObject.getTp()
                / ((confMatObject.getTp() + confMatObject.getFp()) * (confMatObject.getTp() + confMatObject.getFn())
                        / confMatObject.getTotal());
        po.weightLiftUnit = confMatObject.getWeightedTp()
                / ((confMatObject.getWeightedTp() + confMatObject.getWeightedFp())
                        * (confMatObject.getWeightedTp() + confMatObject.getWeightedFn())
                        / confMatObject.getWeightedTotal());
        return po;
    }

    /**
     * Creates a fresh {@link PerformanceObject} for a single curve list, tagged with that list's bin number.
     * <p>
     * Each list gets its own instance, so tagging it cannot overwrite the bin number of the same row stored in
     * another curve list.
     */
    private static PerformanceObject binnedPerformanceObject(ConfusionMatrixObject object, int binNum) {
        PerformanceObject po = setPerformanceObject(object);
        po.binNum = binNum;
        return po;
    }

    /**
     * Samples the ROC, PR and gains curves (and their weighted variants) at {@code numBucket} evenly spaced
     * thresholds, then computes the AUC values.
     * <p>
     * The first row is always kept in every curve. After that, a row is kept in a curve the first time that
     * curve's x-value reaches the next multiple of {@code 1 / numBucket}.
     *
     * @param results   confusion matrices, in file order
     * @param records   number of rows; the denominator of the unweighted gains curve
     * @param numBucket number of buckets per curve
     * @param isWeight  whether to log the weighted curves (they are always computed)
     * @return the bucketed curves and their AUC values
     */
    public PerformanceResult bucketing(Iterable<ConfusionMatrixObject> results, long records, int numBucket,
            boolean isWeight) {
        List<PerformanceObject> fprList = new ArrayList<PerformanceObject>(numBucket + 1);
        List<PerformanceObject> catchRateList = new ArrayList<PerformanceObject>(numBucket + 1);
        List<PerformanceObject> gainList = new ArrayList<PerformanceObject>(numBucket + 1);
        List<PerformanceObject> fprWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
        List<PerformanceObject> catchRateWeightList = new ArrayList<PerformanceObject>(numBucket + 1);
        List<PerformanceObject> gainWeightList = new ArrayList<PerformanceObject>(numBucket + 1);

        int fpBin = 1, tpBin = 1, gainBin = 1, fpWeightBin = 1, tpWeightBin = 1, gainWeightBin = 1;
        double binCapacity = 1.0 / numBucket;
        boolean isFirst = true;
        // long so the row index cannot overflow against a long record count
        long i = 0;
        for(ConfusionMatrixObject object: results) {
            PerformanceObject po = setPerformanceObject(object);
            if(isFirst) {
                // hit rate (precision) == NaN on the first row, so pin it
                po.precision = 1.0;
                po.weightedPrecision = 1.0;
                // lift = NaN
                po.liftUnit = 0.0;
                po.weightLiftUnit = 0.0;
                // shared on purpose: this instance is never modified after being added
                fprList.add(po);
                catchRateList.add(po);
                gainList.add(po);
                fprWeightList.add(po);
                catchRateWeightList.add(po);
                gainWeightList.add(po);
                isFirst = false;
            } else {
                if(po.fpr >= fpBin * binCapacity) {
                    fprList.add(binnedPerformanceObject(object, fpBin++));
                }
                if(po.recall >= tpBin * binCapacity) {
                    catchRateList.add(binnedPerformanceObject(object, tpBin++));
                }
                // prevent 99%: use (i + 1) so the last row reaches a ratio of 1.0
                if((double) (i + 1) / records >= gainBin * binCapacity) {
                    gainList.add(binnedPerformanceObject(object, gainBin++));
                }
                if(po.weightedFpr >= fpWeightBin * binCapacity) {
                    fprWeightList.add(binnedPerformanceObject(object, fpWeightBin++));
                }
                if(po.weightedRecall >= tpWeightBin * binCapacity) {
                    catchRateWeightList.add(binnedPerformanceObject(object, tpWeightBin++));
                }
                // NOTE(review): the "+ 1" looks like it mirrors the (i + 1) above for weighted counts; confirm intent
                if((object.getWeightedTp() + object.getWeightedFp() + 1) / object.getWeightedTotal() >= gainWeightBin
                        * binCapacity) {
                    gainWeightList.add(binnedPerformanceObject(object, gainWeightBin++));
                }
            }
            i++;
        }

        logResult(fprList, "Bucketing False Positive Rate");
        if(isWeight) {
            logResult(fprWeightList, "Bucketing Weighted False Positive Rate");
        }
        logResult(catchRateList, "Bucketing Catch Rate");
        if(isWeight) {
            logResult(catchRateWeightList, "Bucketing Weighted Catch Rate");
        }
        logResult(gainList, "Bucketing Action rate");
        if(isWeight) {
            logResult(gainWeightList, "Bucketing Weighted action rate");
        }

        PerformanceResult result = new PerformanceResult();
        result.version = Constants.version;
        result.pr = catchRateList;
        result.weightedPr = catchRateWeightList;
        result.roc = fprList;
        result.weightedRoc = fprWeightList;
        result.gains = gainList;
        result.weightedGains = gainWeightList;
        // Calculate area under curve
        result.areaUnderRoc = AreaUnderCurve.ofRoc(result.roc);
        result.weightedAreaUnderRoc = AreaUnderCurve.ofWeightedRoc(result.weightedRoc);
        result.areaUnderPr = AreaUnderCurve.ofPr(result.pr);
        result.weightedAreaUnderPr = AreaUnderCurve.ofWeightedPr(result.weightedPr);
        logAucResult(result, isWeight);
        return result;
    }

    /** Logs the ROC and PR AUC values; the weighted AUC values are logged only when {@code isWeight} is true. */
    static void logAucResult(PerformanceResult result, boolean isWeight) {
        log.info("AUC value of ROC: {}", result.areaUnderRoc);
        log.info("AUC value of PR: {}", result.areaUnderPr);
        if(isWeight) {
            log.info("AUC value of weighted ROC: {}", result.weightedAreaUnderRoc);
            log.info("AUC value of weighted PR: {}", result.weightedAreaUnderPr);
        }
    }

    /**
     * Logs one curve as a fixed-width table.
     * <p>
     * Rate columns are rounded to at most 4 decimals; the bin lowest score is printed unformatted.
     */
    static void logResult(List<PerformanceObject> list, String info) {
        DecimalFormat df = new DecimalFormat("#.####");
        String formatString = "%10s %18s %10s %18s %15s %18s %10s %11s %10s";
        log.info("Start print: " + info);
        log.info(String.format(formatString, "ActionRate", "WeightedActionRate", "Recall", "WeightedRecall",
                "Precision", "WeightedPrecision", "FPR", "WeightedFPR", "BinLowestScore"));
        for(PerformanceObject po: list) {
            log.info(String.format(formatString, df.format(po.actionRate), df.format(po.weightedActionRate),
                    df.format(po.recall), df.format(po.weightedRecall), df.format(po.precision),
                    df.format(po.weightedPrecision), df.format(po.fpr), df.format(po.weightedFpr), po.binLowestScore));
        }
    }
}