package quickml.supervised.crossValidation.lossfunctions.rankingLossFunctions;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.supervised.rankingModels.ItemToOutcomeMap;
import quickml.supervised.rankingModels.LabelPredictionWeightForRanking;
import quickml.supervised.rankingModels.RankingPrediction;
import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
 * Normalized discounted cumulative gain (NDCG@k) exposed as a ranking loss.
 *
 * <p>For each instance, the DCG@k of the predicted ranking is divided by the DCG@k of the ideal
 * ordering of that instance's outcomes (IDCG@k). Ranks are 1-based. An item with outcome
 * {@code o} at rank {@code r} contributes {@code (2^o - 1) / log2(1 + r)}.
 *
 * <p>Created by alexanderhawk on 8/13/15.
 */
public class NDCG implements RankingLossFunction {
    private static final Logger logger = LoggerFactory.getLogger(NDCG.class);

    /** Inclusive rank cutoff: only positions 1..k contribute. The default means "no cutoff". */
    int k = Integer.MAX_VALUE;

    /**
     * Creates an NDCG@k loss.
     *
     * @param k inclusive rank cutoff; only the top {@code k} positions contribute to DCG and IDCG
     */
    public NDCG(int k) {
        this.k = k;
    }

    /** Creates an NDCG loss with no rank cutoff. */
    public NDCG() {
    }

    /**
     * Returns the negated sum, over all instances, of each instance's weight times its NDCG@k.
     *
     * <p>The sign is flipped so that a better ranking (higher NDCG) gives a lower loss. The value is
     * a sum over instances, not a mean.
     *
     * @param results per-instance label, predicted ranking and weight
     * @return {@code -sum(weight_i * NDCG@k_i)}
     */
    @Override
    public Double getLoss(List<LabelPredictionWeightForRanking> results) {
        double weightedNdcgSum = 0;
        for (LabelPredictionWeightForRanking lpw : results) {
            weightedNdcgSum += nDCGForInstance(lpw);
        }
        return -weightedNdcgSum;
    }

    @Override
    public String getName() {
        return "NDCG";
    }

    /**
     * Computes the weighted NDCG@k of a single instance.
     *
     * @param lpw label, prediction and weight for one instance
     * @return {@code weight * DCG@k / IDCG@k}, or 0 when IDCG@k is 0
     */
    private double nDCGForInstance(LabelPredictionWeightForRanking lpw) {
        double dcg = dcg(lpw, k);
        double idcg = idcg(lpw, k);
        // IDCG is 0 when there are no outcomes, all outcomes are 0, or k < 1. NDCG is undefined
        // there, and 0/0 = NaN would poison the summed loss, so the instance contributes nothing.
        if (idcg == 0.0) {
            return 0.0;
        }
        return lpw.getWeight() * dcg / idcg;
    }

    /**
     * Computes DCG@k of the predicted ranking over the items present in the label.
     *
     * @param lpw label (item to outcome) and predicted ranking for one instance
     * @param k inclusive rank cutoff; items predicted at 1-based ranks 1..k contribute
     * @return the discounted cumulative gain of the predicted ranking
     */
    public static double dcg(LabelPredictionWeightForRanking lpw, int k) {
        ItemToOutcomeMap ito = lpw.getLabel();
        RankingPrediction rp = lpw.getPrediction();
        double dcg = 0;
        for (Serializable item : ito.getItems()) {
            double outcome = ito.getOutcome(item);
            int rank = rp.getRankOfItem(item);
            // The cutoff is inclusive, matching idcg(). Integer.MAX_VALUE is treated as "not ranked",
            // so the default k = Integer.MAX_VALUE does not pull such items in.
            // NOTE(review): presumably this is the sentinel getRankOfItem returns for items absent
            // from the prediction; confirm against RankingPrediction.
            if (rank <= k && rank != Integer.MAX_VALUE) {
                dcg += dcgSummand(outcome, rank);
            }
        }
        return dcg;
    }

    /**
     * Gain contributed by one item: {@code (2^outcome - 1) / log2(1 + rank)}.
     *
     * @param outcome relevance of the item
     * @param rank 1-based position of the item; at rank 0 the discount is 0 and the result is
     *     infinite or NaN
     * @return the discounted gain
     */
    public static double dcgSummand(double outcome, int rank) {
        double numerator = Math.pow(2, outcome) - 1;
        // Promote to double before adding so that rank == Integer.MAX_VALUE cannot overflow to a
        // negative int (log of a negative number is NaN).
        double denominator = Math.log(1.0 + rank) / Math.log(2);
        return numerator / denominator;
    }

    /**
     * Computes IDCG@k: the DCG@k of the best possible ordering of the label's outcomes.
     *
     * @param lpw label for one instance (the prediction is not used)
     * @param k inclusive rank cutoff; at most the {@code k} largest outcomes contribute
     * @return the ideal discounted cumulative gain (0 when there are no outcomes or k < 1)
     */
    public static double idcg(LabelPredictionWeightForRanking lpw, int k) {
        ItemToOutcomeMap ito = lpw.getLabel();
        List<Double> outcomes = Lists.newArrayList(ito.getOutcomes());
        // The ideal ordering puts the highest outcome at rank 1.
        Collections.sort(outcomes, Collections.<Double>reverseOrder());
        int positions = Math.min(k, outcomes.size());
        double idcg = 0;
        for (int i = 0; i < positions; i++) {
            idcg += dcgSummand(outcomes.get(i), i + 1);
        }
        return idcg;
    }
}