/* * This file is part of ELKI: * Environment for Developing KDD-Applications Supported by Index-Structures * * Copyright (C) 2017 * ELKI Development Team * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package de.lmu.ifi.dbs.elki.evaluation.outlier; import java.util.List; import java.util.regex.Pattern; import de.lmu.ifi.dbs.elki.database.Database; import de.lmu.ifi.dbs.elki.database.DatabaseUtil; import de.lmu.ifi.dbs.elki.database.ids.DBIDIter; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.DBIDs; import de.lmu.ifi.dbs.elki.database.ids.SetDBIDs; import de.lmu.ifi.dbs.elki.database.relation.DoubleRelation; import de.lmu.ifi.dbs.elki.evaluation.Evaluator; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.math.geometry.XYCurve; import de.lmu.ifi.dbs.elki.result.*; import de.lmu.ifi.dbs.elki.result.outlier.OutlierResult; import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer; import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization; import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.PatternParameter; /** * Compute a curve containing the precision values for an outlier detection * method. * * @author Erich Schubert * @since 0.5.0 * * @apiviz.has PRCurve */ public class OutlierPrecisionRecallCurve implements Evaluator { /** * AUC value for PR curve */ public static final String PRAUC_LABEL = "PR AUC"; /** * The logger. */ private static final Logging LOG = Logging.getLogger(OutlierPrecisionRecallCurve.class); /** * Stores the "positive" class. */ private Pattern positiveClassName; /** * Constructor. * * @param positiveClassName Pattern to recognize outliers */ public OutlierPrecisionRecallCurve(Pattern positiveClassName) { super(); this.positiveClassName = positiveClassName; } @Override public void processNewResult(ResultHierarchy hier, Result result) { Database db = ResultUtil.findDatabase(hier); // Prepare SetDBIDs positiveids = DBIDUtil.ensureSet(DatabaseUtil.getObjectsByLabelMatch(db, positiveClassName)); if(positiveids.size() == 0) { LOG.warning("Computing a P/R curve failed - no objects matched."); return; } List<OutlierResult> oresults = OutlierResult.getOutlierResults(result); List<OrderingResult> orderings = ResultUtil.getOrderingResults(result); // Outlier results are the main use case. for(OutlierResult o : oresults) { DBIDs sorted = o.getOrdering().order(o.getOrdering().getDBIDs()); PRCurve curve = computePrecisionResult(o.getScores().size(), positiveids, sorted.iter(), o.getScores()); db.getHierarchy().add(o, curve); EvaluationResult ev = EvaluationResult.findOrCreate(db.getHierarchy(), o, "Evaluation of ranking", "ranking-evaluation"); ev.findOrCreateGroup("Evaluation measures").addMeasure(PRAUC_LABEL, curve.getAUC(), 0., 1., false); // Process them only once. orderings.remove(o.getOrdering()); } // FIXME: find appropriate place to add the derived result // otherwise apply an ordering to the database IDs. for(OrderingResult or : orderings) { DBIDs sorted = or.order(or.getDBIDs()); PRCurve curve = computePrecisionResult(or.getDBIDs().size(), positiveids, sorted.iter(), null); db.getHierarchy().add(or, curve); EvaluationResult ev = EvaluationResult.findOrCreate(db.getHierarchy(), or, "Evaluation of ranking", "ranking-evaluation"); ev.findOrCreateGroup("Evaluation measures").addMeasure(PRAUC_LABEL, curve.getAUC(), 0., 1., false); } } private PRCurve computePrecisionResult(int size, SetDBIDs ids, DBIDIter iter, DoubleRelation scores) { final int postot = ids.size(); int poscnt = 0, total = 0; PRCurve curve = new PRCurve(postot + 2, postot); double prevscore = Double.NaN; for(; iter.valid(); iter.advance()) { // Previous precision rate - y axis final double curprec = ((double) poscnt) / total; // Previous recall rate - x axis final double curreca = ((double) poscnt) / postot; // Analyze next point // positive or negative match? if(ids.contains(iter)) { poscnt += 1; } total += 1; // First iteration ends here if(total == 1) { continue; } // defer calculation for ties if(scores != null) { double curscore = scores.doubleValue(iter); if(Double.compare(prevscore, curscore) == 0) { continue; } prevscore = curscore; } // Add a new point (for the previous entry - because of tie handling!) curve.addAndSimplify(curreca, curprec); } // End curve - always at all positives found. curve.addAndSimplify(1.0, postot / total); return curve; } /** * P/R Curve * * @author Erich Schubert */ public static class PRCurve extends XYCurve { /** * Area under curve */ double auc = Double.NaN; /** * Number of positive observations */ int positive; /** * Constructor. * * @param size Size estimation * @param positive Number of positive elements (for AUC correction) */ public PRCurve(int size, int positive) { super("Recall", "Precision", size); this.positive = positive; } @Override public String getLongName() { return "Precision-Recall-Curve"; } @Override public String getShortName() { return "pr-curve"; } /** * Get AUC value * * @return AUC value */ public double getAUC() { if(Double.isNaN(auc)) { double max = 1 - 1. / positive; auc = areaUnderCurve(this) / max; } return auc; } } /** * Parameterization class. * * @author Erich Schubert * * @apiviz.exclude */ public static class Parameterizer extends AbstractParameterizer { /** * The pattern to identify positive classes. * * <p> * Key: {@code -precision.positive} * </p> */ public static final OptionID POSITIVE_CLASS_NAME_ID = new OptionID("precision.positive", "Class label for the 'positive' class."); protected Pattern positiveClassName = null; @Override protected void makeOptions(Parameterization config) { super.makeOptions(config); PatternParameter positiveClassNameP = new PatternParameter(POSITIVE_CLASS_NAME_ID); if(config.grab(positiveClassNameP)) { positiveClassName = positiveClassNameP.getValue(); } } @Override protected OutlierPrecisionRecallCurve makeInstance() { return new OutlierPrecisionRecallCurve(positiveClassName); } } }