/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.evaluation; import java.util.List; import javax.annotation.Nonnull; /** * Binary responses measures for item recommendation (i.e. ranking problems) * * References: B. McFee and G. R. Lanckriet. "Metric Learning to Rank" ICML 2010. */ public final class BinaryResponsesMeasures { private BinaryResponsesMeasures() {} /** * Computes binary nDCG (i.e. relevance score is 0 or 1) * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return nDCG */ public static double nDCG(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { double dcg = 0.d; double idcg = IDCG(Math.min(recommendSize, groundTruth.size())); for (int i = 0, n = recommendSize; i < n; i++) { Object item_id = rankedList.get(i); if (!groundTruth.contains(item_id)) { continue; } int rank = i + 1; dcg += Math.log(2) / Math.log(rank + 1); } return dcg / idcg; } /** * Computes the ideal DCG * * @param n the number of positive items * @return ideal DCG */ public static double IDCG(final int n) { double idcg = 0.d; for (int i = 0; i < n; i++) { idcg += Math.log(2) / Math.log(i + 2); } return idcg; } /** * Computes Precision@`recommendSize` * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return Precision */ public static double Precision(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { return (double) countTruePositive(rankedList, groundTruth, recommendSize) / recommendSize; } /** * Computes Recall@`recommendSize` * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return Recall */ public static double Recall(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { return (double) countTruePositive(rankedList, groundTruth, recommendSize) / groundTruth.size(); } /** * Counts the number of true positives * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return number of true positives */ public static int countTruePositive(final List<?> rankedList, final List<?> groundTruth, final int recommendSize) { int nTruePositive = 0; for (int i = 0, n = recommendSize; i < n; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { nTruePositive++; } } return nTruePositive; } /** * Computes Mean Reciprocal Rank (MRR) * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return MRR */ public static double MRR(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { for (int i = 0, n = recommendSize; i < n; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { return 1.0 / (i + 1.0); } } return 0.0; } /** * Computes Mean Average Precision (MAP) * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return MAP */ public static double MAP(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { int nTruePositive = 0; double sumPrecision = 0.0; // accumulate precision@1 to @recommendSize for (int i = 0, n = recommendSize; i < n; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { nTruePositive++; sumPrecision += nTruePositive / (i + 1.0); } } return sumPrecision / groundTruth.size(); } /** * Computes the area under the ROC curve (AUC) * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return AUC */ public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { int nTruePositive = 0, nCorrectPairs = 0; // count # of pairs of items that are ranked in the correct order (i.e. TP > FP) for (int i = 0, n = recommendSize; i < n; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { // # of true positives which are ranked higher position than i-th recommended item nTruePositive++; } else { // for each FP item, # of correct ordered <TP, FP> pairs equals to # of TPs at i-th position nCorrectPairs += nTruePositive; } } // # of all possible <TP, FP> pairs int nPairs = nTruePositive * (recommendSize - nTruePositive); // AUC can equivalently be calculated by counting the portion of correctly ordered pairs return (double) nCorrectPairs / nPairs; } }