package edu.umn.cs.recsys.svd;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import edu.umn.cs.recsys.dao.*;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.core.LenskitConfiguration;
import org.grouplens.lenskit.core.LenskitRecommender;
import org.grouplens.lenskit.data.dao.EventDAO;
import org.grouplens.lenskit.data.dao.ItemDAO;
import org.grouplens.lenskit.data.dao.UserDAO;
import org.grouplens.lenskit.vectors.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Command-line driver that builds a LensKit recommender around {@link SVDItemScorer}
 * and prints one <tt>user,item,score,title</tt> line per requested prediction on
 * standard output.
 *
 * @author <a href="http://www.grouplens.org">GroupLens Research</a>
 */
public class SVDMain {
    // NOTE(review): the logger name looks carried over from a different assignment; it is
    // left unchanged because external logging configuration may refer to it -- confirm.
    private static final Logger logger = LoggerFactory.getLogger("ii-assignment");
    /** Matches a <tt>user:item</tt> argument; group 1 is the user ID, group 2 the item ID. */
    private static final Pattern USER_ITEM_PAT = Pattern.compile("(\\d+):(\\d+)");

    /**
     * Main entry point to the program.
     * @param args The <tt>user:item</tt> pairs to score.
     */
    public static void main(String[] args) {
        SVDMain program = initialize(args);
        program.run();
    }

    /**
     * Parse arguments and set up an SVD runner.
     *
     * <p>Recognized flags are <tt>--global-mean</tt>, <tt>--user-mean</tt>,
     * <tt>--item-mean</tt>, <tt>--user-item-mean</tt> (the last one given wins) and
     * <tt>--all</tt>.  Every other argument must be a <tt>user:item</tt> pair.  When
     * <tt>--all</tt> is present, explicit pairs are ignored regardless of where they
     * appear on the command line.
     *
     * @param args The command line arguments.
     * @return The SVD program, configured and ready to run.
     * @throws IllegalArgumentException if an argument is an unknown flag or is not a
     *         parseable <tt>user:item</tt> pair.
     */
    public static SVDMain initialize(String[] args) {
        BaselineMode baselineMode = BaselineMode.GLOBAL_MEAN;
        Map<Long,Set<Long>> toScore = Maps.newHashMap();
        // Tracked separately instead of nulling out the map mid-parse, so that a
        // user:item pair following --all cannot dereference a null map.
        boolean scoreAll = false;
        for (String arg: args) {
            logger.debug("parsing argument: {}", arg);
            if (arg.equals("--global-mean")) {
                baselineMode = BaselineMode.GLOBAL_MEAN;
            } else if (arg.equals("--user-mean")) {
                baselineMode = BaselineMode.USER_MEAN;
            } else if (arg.equals("--item-mean")) {
                baselineMode = BaselineMode.ITEM_MEAN;
            } else if (arg.equals("--user-item-mean")) {
                baselineMode = BaselineMode.USER_ITEM_MEAN;
            } else if (arg.equals("--all")) {
                scoreAll = true;
            } else if (arg.startsWith("--")) {
                throw new IllegalArgumentException("unknown flag " + arg);
            } else {
                addRequest(toScore, arg);
            }
        }
        // A null request map is the signal for "score everything" (see run()).
        return new SVDMain(baselineMode, scoreAll ? null : toScore);
    }

    /**
     * Parse one <tt>user:item</tt> argument and record it in the request map.
     *
     * @param requests The map of user ID to requested item IDs; modified in place.
     * @param arg The raw command line argument.
     * @throws IllegalArgumentException if the argument is not of the form
     *         <tt>digits:digits</tt> or an ID does not fit in a <tt>long</tt>.
     */
    private static void addRequest(Map<Long,Set<Long>> requests, String arg) {
        Matcher m = USER_ITEM_PAT.matcher(arg);
        if (!m.matches()) {
            throw new IllegalArgumentException("unparseable argument " + arg);
        }
        long uid;
        long iid;
        try {
            uid = Long.parseLong(m.group(1));
            iid = Long.parseLong(m.group(2));
        } catch (NumberFormatException e) {
            // The pattern only guarantees digits, not that the value fits in a long.
            throw new IllegalArgumentException("unparseable argument " + arg, e);
        }
        Set<Long> items = requests.get(uid);
        if (items == null) {
            items = Sets.<Long>newHashSet();
            requests.put(uid, items);
        }
        items.add(iid);
    }

    /** The baseline mode applied to the recommender configuration. */
    BaselineMode baselineMode;
    /**
     * The items to score for each user, keyed by user ID.  {@code null} means
     * "every item for every user"; {@link #run()} fills it in from the DAOs.
     */
    Map<Long,Set<Long>> toScore;

    /**
     * Construct a new SVD program.
     * @param base The baseline mode.
     * @param requests The items to score for each user, or {@code null} to score
     *                 every item for every user.
     */
    public SVDMain(BaselineMode base, Map<Long,Set<Long>> requests) {
        baselineMode = base;
        toScore = requests;
    }

    /**
     * Create the LensKit recommender configuration.
     * @return The LensKit recommender configuration.
     */
    // LensKit configuration API generates some unchecked warnings, turn them off
    @SuppressWarnings("unchecked")
    private LenskitConfiguration configureRecommender() {
        LenskitConfiguration config = new LenskitConfiguration();
        // configure the rating data source
        config.bind(EventDAO.class)
              .to(MOOCRatingDAO.class);
        config.set(RatingFile.class)
              .to(new File("data/ratings.csv"));

        // use custom item and user DAOs
        // our item DAO has title information
        config.bind(ItemDAO.class)
              .to(MOOCItemDAO.class);
        // NOTE(review): this registers UserDAO as a root directly after the ItemDAO
        // binding and is repeated below; possibly an item DAO root was intended.
        // Left as-is to preserve the existing configuration -- confirm.
        config.addRoot(UserDAO.class);
        // and title file
        config.set(TitleFile.class)
              .to(new File("data/movie-titles.csv"));

        // our user DAO can look up by user name
        config.bind(UserDAO.class)
              .to(MOOCUserDAO.class);
        config.addRoot(UserDAO.class);
        config.set(UserFile.class)
              .to(new File("data/users.csv"));

        // use the SVD item scorer to score items
        config.bind(ItemScorer.class)
              .to(SVDItemScorer.class);
        // let the selected baseline mode add its own configuration
        baselineMode.configure(config);
        config.set(LatentFeatureCount.class)
              .to(10);
        return config;
    }

    /**
     * Build the recommender, score the requested user/item pairs, and print one
     * <tt>user,item,score,title</tt> line per pair to standard output.  The score is
     * printed as <tt>NA</tt> when the scorer produced no value for the item.
     *
     * <p>If the recommender cannot be built or a required DAO is unavailable, the
     * error is logged and the JVM exits with status 2.
     */
    public void run() {
        LenskitConfiguration config = configureRecommender();
        LenskitRecommender rec;
        try {
            rec = LenskitRecommender.build(config);
        } catch (RecommenderBuildException e) {
            logger.error("error building recommender", e);
            System.exit(2);
            throw new AssertionError(); // to de-confuse unreachable code detection
        }

        // Get the item title DAO, so we can look up movie titles
        ItemTitleDAO titleDAO = rec.get(ItemTitleDAO.class);
        if (titleDAO == null) {
            // Fail the same way as a missing user DAO instead of with a later NPE.
            logger.error("no item title DAO");
            System.exit(2);
            throw new AssertionError(); // to de-confuse unreachable code detection
        }
        // Get the item scorer and go!
        ItemScorer scorer = rec.getItemScorer();
        assert scorer != null;

        if (toScore == null) {
            logger.debug("loading user/item sets");
            UserDAO userDAO = rec.get(UserDAO.class);
            if (userDAO == null) {
                logger.error("no user DAO");
                System.exit(2);
                throw new AssertionError(); // to de-confuse unreachable code detection
            }
            // The item set does not depend on the user, so fetch it once.
            Set<Long> allItems = titleDAO.getItemIds();
            toScore = Maps.newHashMap();
            for (Long user: userDAO.getUserIds()) {
                toScore.put(user, allItems);
            }
        }

        logger.info("scoring for {} users", toScore.size());
        for (Map.Entry<Long,Set<Long>> scoreRequest: toScore.entrySet()) {
            long user = scoreRequest.getKey();
            Set<Long> items = scoreRequest.getValue();
            logger.info("scoring {} items for user {}", items.size(), user);
            // We call the score method that takes a set of items.
            // AbstractItemScorer delegates this method to the one you are supposed to implement.
            SparseVector scores = scorer.score(user, items);
            for (long item: items) {
                String score;
                if (scores.containsKey(item)) {
                    score = String.format(Locale.ROOT, "%.4f", scores.get(item));
                } else {
                    score = "NA";
                }
                String title = titleDAO.getItemTitle(item);
                // Locale.ROOT keeps the numeric IDs locale-independent, matching the score.
                System.out.format(Locale.ROOT, "%d,%d,%s,%s\n", user, item, score, title);
            }
        }
    }
}