package edu.umn.cs.recsys.svd; import com.google.common.base.Preconditions; import org.apache.commons.math3.linear.RealMatrix; import org.grouplens.grapht.annotation.DefaultProvider; import org.grouplens.lenskit.core.Shareable; import org.grouplens.lenskit.indexes.IdIndexMapping; import javax.annotation.Nullable; import java.io.Serializable; /** * SVD model for collaborative filtering. */ @Shareable @DefaultProvider(SVDModelBuilder.class) public class SVDModel implements Serializable { private static final long serialVersionUID = 1L; private final IdIndexMapping userMapping; private final IdIndexMapping itemMapping; private final RealMatrix userFeatureMatrix; private final RealMatrix itemFeatureMatrix; private final RealMatrix featureWeights; /** * Construct an SVD model. The matrices represent the decomposition, such that the predictions * are equal to {@code umat * weights * imat.transpose()}. * * @param umap The mapping between user IDs and row numbers. * @param imap The mapping between item IDs and row numbers. * @param umat The user feature matrix (users x features) * @param imat The item feature matrix (items x features) * @param weights The singular value matrix (diagonal matrix, features x features) */ SVDModel(IdIndexMapping umap, IdIndexMapping imap, RealMatrix umat, RealMatrix imat, RealMatrix weights) { Preconditions.checkArgument(weights.isSquare(), "singular value matrix is not square"); Preconditions.checkArgument(umat.getColumnDimension() == weights.getRowDimension(), "user matrix has incorrect column dimension"); Preconditions.checkArgument(imat.getColumnDimension() == weights.getColumnDimension(), "item matrix has incorrect column dimension"); userMapping = umap; itemMapping = imap; userFeatureMatrix = umat; itemFeatureMatrix = imat; featureWeights = weights; } /** * Get the feature weights. This is a diagonal matrix. * @return The diagonal matrix of feature weights. */ public RealMatrix getFeatureWeights() { return featureWeights; } /** * Get a user feature vector. This is a row vector whose values (columns) are the feature * values for a particular user. * * @param user The user ID. * @return The feature vector for user {@code user}, or {@code null} if the user is unkonwn. */ @Nullable public RealMatrix getUserVector(long user) { int row = userMapping.tryGetIndex(user); if (row >= 0) { return userFeatureMatrix.getRowMatrix(row); } else { return null; } } /** * Get a item feature vector. This is a row vector whose values (columns) are the feature * values for a particular item. * * * @param item The item ID. * @return The feature vector for item {@code item}. */ public RealMatrix getItemVector(long item) { int row = itemMapping.tryGetIndex(item); if (row >= 0) { return itemFeatureMatrix.getRowMatrix(row); } else { return null; } } /** * Get a item feature vector matrix. Its rows are items and its columns are latent features. * * @return The item-feature matrix (this must not be modified). */ public RealMatrix getItemFeatureMatrix() { return itemFeatureMatrix; } /** * Get the user index mapping. * @return The mapping between user IDs and matrix row numbers. */ public IdIndexMapping getUserIndexMapping() { return userMapping; } /** * Get the item index mapping. * @return The mapping between item IDs and matrix row numbers. */ public IdIndexMapping getItemIndMapping() { return itemMapping; } /** * Get the row number for an item in the item-feature matrix. * @param item The item ID. * @return The row number for the item. * @throws IllegalArgumentException if the item is unknown. */ public int getItemRow(long item) { return itemMapping.getIndex(item); } }