package edu.umn.cs.recsys.svd; import org.apache.commons.math3.linear.RealMatrix; import org.grouplens.lenskit.ItemScorer; import org.grouplens.lenskit.baseline.BaselineScorer; import org.grouplens.lenskit.basic.AbstractItemScorer; import org.grouplens.lenskit.data.dao.UserEventDAO; import org.grouplens.lenskit.data.event.Rating; import org.grouplens.lenskit.data.history.History; import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer; import org.grouplens.lenskit.data.history.UserHistory; import org.grouplens.lenskit.vectors.MutableSparseVector; import org.grouplens.lenskit.vectors.SparseVector; import org.grouplens.lenskit.vectors.VectorEntry; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nonnull; import javax.inject.Inject; /** * SVD-based item scorer. */ public class SVDItemScorer extends AbstractItemScorer { private static final Logger logger = LoggerFactory .getLogger(SVDItemScorer.class); private final SVDModel model; private final ItemScorer baselineScorer; private final UserEventDAO userEvents; /** * Construct an SVD item scorer using a model. * * @param m * The model to use when generating scores. * @param uedao * A DAO to get user rating profiles. * @param baseline * The baseline scorer (providing means). */ @Inject public SVDItemScorer(SVDModel m, UserEventDAO uedao, @BaselineScorer ItemScorer baseline) { model = m; baselineScorer = baseline; userEvents = uedao; } /** * Score items in a vector. The key domain of the provided vector is the * items to score, and the score method sets the values for each item to its * score (or unsets it, if no score can be provided). The previous values * are discarded. * * @param user * The user ID. * @param scores * The score vector. */ @Override public void score(long user, @Nonnull MutableSparseVector scores) { // P = b + U.S.Vt if (model.getUserVector(user) == null) { scores.clear(); } else { RealMatrix U = model.getUserVector(user); RealMatrix S = model.getFeatureWeights(); for (VectorEntry e : scores.fast(VectorEntry.State.EITHER)) { long item = e.getKey(); RealMatrix V = model.getItemVector(item); scores.set(item, baselineScorer.score(user, item) + (U.multiply(S)).multiply(V.transpose()) .getEntry(0, 0)); } } } /** * Get a user's ratings. * * @param user * The user ID. * @return The ratings to retrieve. */ private SparseVector getUserRatingVector(long user) { UserHistory<Rating> history = userEvents.getEventsForUser(user, Rating.class); if (history == null) { history = History.forUser(user); } return RatingVectorUserHistorySummarizer.makeRatingVector(history); } }