/* * This file is part of ALOE. * * ALOE is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * ALOE is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * You should have received a copy of the GNU General Public License * along with ALOE. If not, see <http://www.gnu.org/licenses/>. * * Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl) */ package etc.aloe.data; import etc.aloe.processes.Saving; import java.io.IOException; import java.io.OutputStream; import java.io.PrintStream; import java.util.ArrayList; import java.util.Collection; import java.util.List; /** * The EvaluationReport contains data about model performance as compared to a * source of truth data. * * @author Michael Brooks <mjbrooks@uw.edu> */ public class EvaluationReport implements Saving { private int truePositiveCount; private int trueNegativeCount; private int falsePositiveCount; private int falseNegativeCount; private double falsePositiveCost = 1; private double falseNegativeCost = 1; private List<ROC> rocs = new ArrayList<ROC>(); private List<SegmentSet> testSets = new ArrayList<SegmentSet>(); private List<String> testSetNames = new ArrayList<String>(); private final String name; /** * Construct an equal-cost evaluation report */ public EvaluationReport(String name) { this.name = name; } /** * Construct a cost-sensitive evaluation report. * * @param falsePositiveCost * @param falseNegativeCost */ public EvaluationReport(String name, double falsePositiveCost, double falseNegativeCost) { this.name = name; this.falsePositiveCost = falsePositiveCost; this.falseNegativeCost = falseNegativeCost; } /** * Get the number of examples with a positive prediction that was correct. * * @return */ public int getTruePositiveCount() { return truePositiveCount; } /** * Get the number of examples with a negative prediction that was correct. * * @return */ public int getTrueNegativeCount() { return trueNegativeCount; } /** * Get the number of examples with a positive prediction that was incorrect. * * @return */ public int getFalsePositiveCount() { return falsePositiveCount; } /** * Get the number of examples with a negative prediction that was incorrect. * * @return */ public int getFalseNegativeCount() { return falseNegativeCount; } /** * Set the number of examples with a positive prediction that was correct. * * @param truePositiveCount */ public void setTruePositiveCount(int truePositiveCount) { this.truePositiveCount = truePositiveCount; } /** * Set the number of examples with a negative prediction that was correct. * * @param trueNegativeCount */ public void setTrueNegativeCount(int trueNegativeCount) { this.trueNegativeCount = trueNegativeCount; } /** * Set the number of examples with a positive prediction that was incorrect. * * @param falsePositiveCount */ public void setFalsePositiveCount(int falsePositiveCount) { this.falsePositiveCount = falsePositiveCount; } /** * Set the number of examples with a negative prediction that was incorrect. * * @param falseNegativeCount */ public void setFalseNegativeCount(int falseNegativeCount) { this.falseNegativeCount = falseNegativeCount; } /** * Recall = (correctly classified positives) / (total actual positives) * * @return */ public double getRecall() { return (double) truePositiveCount / (truePositiveCount + falseNegativeCount); } /** * Precision = (correctly classified positives) / (total predicted as * positive) * * @return */ public double getPrecision() { return (double) truePositiveCount / (truePositiveCount + falsePositiveCount); } /** * FMeasure = (2 * recall * precision) / (recall + precision) * * @return */ public double getFMeasure() { double precision = getPrecision(); double recall = getRecall(); if ((precision + recall) == 0) { return 0; } return 2 * precision * recall / (precision + recall); } /** * PercentCorrect = (TP + TN) / (TP + TN + FP + FN) * * @return */ public double getPercentCorrect() { return (double) (truePositiveCount + trueNegativeCount) / getTotalExamples(); } /** * PercentIncorrect = (FP + FN) / (TP + TN + FP + FN) * * @return */ public double getPercentIncorrect() { return (double) (falsePositiveCount + falseNegativeCount) / getTotalExamples(); } /** * Get the number of positive examples in the truth data. * * @return */ public int getNumTrulyPositive() { return getTruePositiveCount() + getFalseNegativeCount(); } /** * Get the number of negative examples in the truth data. * * @return */ public int getNumTrulyNegative() { return getTrueNegativeCount() + getFalsePositiveCount(); } /** * Get the number of examples predicted to be positive. * * @return */ public int getNumPredictedPositive() { return getTruePositiveCount() + getFalsePositiveCount(); } /** * Get the number of examples predicted to be negative. * * @return */ public int getNumPredictedNegative() { return getTrueNegativeCount() + getFalseNegativeCount(); } /** * Get Cohen's kappa, the probability of agreement with the truth data, * corrected by the probability of random agreement. * * See http://en.wikipedia.org/wiki/Cohen's_kappa * * @return */ public double getCohensKappa() { double probabilityAgreement = getPercentCorrect(); double numPositiveInTruth = getNumTrulyPositive(); double numPositiveInPrediction = getNumPredictedPositive(); double numNegativeInTruth = getNumTrulyNegative(); double numNegativeInPrediction = getNumPredictedNegative(); double probabilityRandomPositiveAgreement = (numPositiveInTruth / getTotalExamples()) * (numPositiveInPrediction / getTotalExamples()); double probabilityRandomNegativeAgreement = (numNegativeInTruth / getTotalExamples()) * (numNegativeInPrediction / getTotalExamples()); double probabilityRandomAgreement = probabilityRandomPositiveAgreement + probabilityRandomNegativeAgreement; double kappa = (probabilityAgreement - probabilityRandomAgreement) / (1 - probabilityRandomAgreement); return kappa; } /** * Gets the total cost of all misclassified examples. * * @return */ public double getTotalCost() { return falsePositiveCount * falsePositiveCost + falseNegativeCount * falseNegativeCost; } /** * Gets the cost of all misclassified examples divided by the total number * of examples. * * @return */ public double getAverageCost() { return getTotalCost() / getTotalExamples(); } /** * Get the total number of examples classified. * * @return */ public int getTotalExamples() { return truePositiveCount + trueNegativeCount + falsePositiveCount + falseNegativeCount; } @Override public boolean save(OutputStream destination) throws IOException { PrintStream out = new PrintStream(destination); out.print(this.toString()); out.flush(); return true; } /** * Get the evaluation report as a string. * * @return */ @Override public String toString() { return "Examples: " + getTotalExamples() + "\n" + "Positive: " + getNumTrulyPositive() + "\n" + "Negative: " + getNumTrulyNegative() + "\n" + "FP Cost: " + falsePositiveCost + "\n" + "FN Cost: " + falseNegativeCost + "\n" + "------------------\n" + "TP: " + truePositiveCount + "\n" + "FP: " + falsePositiveCount + "\n" + "TN: " + trueNegativeCount + "\n" + "FN: " + falseNegativeCount + "\n" + "------------------\n" + "Positive Predictions: " + getNumPredictedPositive() + "\n" + "Negative Predictions: " + getNumPredictedNegative() + "\n" + "------------------\n" + "Precision: " + getPrecision() + "\n" + "Recall: " + getRecall() + "\n" + "FMeasure: " + getFMeasure() + "\n" + "------------------\n" + "% Correct: " + getPercentCorrect() + "\n" + "% Incorrect: " + getPercentIncorrect() + "\n" + "Kappa: " + getCohensKappa() + "\n" + "------------------\n" + "Total Cost: " + getTotalCost() + "\n" + "Avg Cost: " + getAverageCost(); } /** * Add a partial evaluation report to this report. Modifies the current * report. * * @param report */ public void addPartial(EvaluationReport report) { truePositiveCount += report.truePositiveCount; trueNegativeCount += report.trueNegativeCount; falsePositiveCount += report.falsePositiveCount; falseNegativeCount += report.falseNegativeCount; this.rocs.addAll(report.getROCs()); this.testSets.addAll(report.getTestSets()); this.testSetNames.addAll(report.getTestSetNames()); } /** * Add some test data with labels to the report, for later export. * * @param testingSegments */ public void addLabeledTestData(SegmentSet testingSegments) { this.testSets.add(testingSegments); this.testSetNames.add(this.getName()); } /** * Evaluate the given predictions. * * @param predictions */ public void addPredictions(Predictions predictions) { ROC roc = new ROC(this.getName()); roc.calculateCurve(predictions); this.rocs.add(roc); this.setTruePositiveCount(predictions.getTruePositiveCount()); this.setFalsePositiveCount(predictions.getFalsePositiveCount()); this.setTrueNegativeCount(predictions.getTrueNegativeCount()); this.setFalseNegativeCount(predictions.getFalseNegativeCount()); } public String getName() { return name; } /** * Get the labeled test sets from cross validation. * @return */ public List<SegmentSet> getTestSets() { return testSets; } /** * Get a list of named ROC curves included in this report. * * @return */ public List<ROC> getROCs() { return rocs; } public List<String> getTestSetNames() { return this.testSetNames; } }