/*
* Copyright (C) 2016 RankSys http://ranksys.org
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
*/
package org.ranksys.lda;
import cc.mallet.topics.ParallelTopicModel;
import es.uam.eps.ir.ranksys.fast.FastRecommendation;
import es.uam.eps.ir.ranksys.fast.index.FastItemIndex;
import es.uam.eps.ir.ranksys.fast.index.FastUserIndex;
import es.uam.eps.ir.ranksys.fast.utils.topn.IntDoubleTopN;
import es.uam.eps.ir.ranksys.rec.fast.AbstractFastRecommender;
import static java.lang.Math.min;
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntPredicate;
import static java.util.stream.Collectors.toList;
import org.ranksys.core.util.tuples.Tuple2id;
/**
* LDA recommender. See ParallelTopicModel in Mallet (http://mallet.cs.umass.edu/) for more details.
*
* @author Saúl Vargas (Saul.Vargas@glasgow.ac.uk)
*
* @param <U> user type
* @param <I> item type
*/
public class LDARecommender<U, I> extends AbstractFastRecommender<U, I> {

    /**
     * Trained Mallet topic model. This class uses the user index as the
     * model's instance (document) id and the item index as its type (word)
     * id — NOTE(review): assumes the model was trained with that mapping.
     */
    private final ParallelTopicModel topicModel;

    /**
     * Constructor
     *
     * @param uIndex user index
     * @param iIndex item index
     * @param topicModel LDA topic model
     */
    public LDARecommender(FastUserIndex<U> uIndex, FastItemIndex<I> iIndex, ParallelTopicModel topicModel) {
        super(uIndex, iIndex);
        this.topicModel = topicModel;
    }

    /**
     * Scores every item accepted by {@code filter} and returns the best
     * {@code min(maxLength, numItems())} of them in descending score order.
     * The score of an item is {@code sum_z p(z|u) * n(i,z) / n(z)}, where
     * {@code p(z|u)} is the user's topic distribution, {@code n(i,z)} the
     * count of the item in topic {@code z} and {@code n(z)} the total token
     * count of topic {@code z}.
     *
     * @param uidx index of the user; a negative value yields an empty recommendation
     * @param maxLength maximum number of items to return
     * @param filter predicate deciding which item indexes may be recommended
     * @return recommendation for the user
     */
    @Override
    public FastRecommendation getRecommendation(int uidx, int maxLength, IntPredicate filter) {
        // RankSys indexes use -1 for users that are not in the index; such a
        // user has no document in the topic model, so Mallet would throw.
        if (uidx < 0) {
            return new FastRecommendation(uidx, new ArrayList<>(0));
        }

        // p(z|u) depends only on the user: compute it once instead of once per
        // candidate item (Mallet rebuilds the array on every call).
        double[] userTopics = topicModel.getTopicProbabilities(uidx);

        IntDoubleTopN topN = new IntDoubleTopN(min(maxLength, numItems()));
        for (int iidx = 0; iidx < numItems(); iidx++) {
            if (filter.test(iidx)) {
                topN.add(iidx, score(userTopics, iidx));
            }
        }
        topN.sort();

        List<Tuple2id> items = topN.reverseStream().collect(toList());
        return new FastRecommendation(uidx, items);
    }

    /**
     * Computes the LDA score of one item for a user whose topic distribution
     * has already been obtained.
     *
     * @param userTopics topic probabilities of the user, indexed by topic
     * @param iidx index of the item (used as the Mallet type id)
     * @return sum over the item's topics of {@code p(z|u) * n(i,z) / n(z)}
     */
    private double score(double[] userTopics, int iidx) {
        // Each entry packs a topic id in the low bits and its count in the high
        // bits; the loop stops at the first non-positive entry, which is
        // presumably the end of the used slots in Mallet's layout.
        int[] packedCounts = topicModel.typeTopicCounts[iidx];
        double score = 0.0;
        for (int i = 0; i < packedCounts.length && packedCounts[i] > 0; i++) {
            int z = packedCounts[i] & topicModel.topicMask;
            int n = packedCounts[i] >> topicModel.topicBits;
            score += userTopics[z] * (n / (double) topicModel.tokensPerTopic[z]);
        }
        return score;
    }
}