package edu.umn.cs.recsys.ii;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import edu.umn.cs.recsys.dao.*;
import org.grouplens.lenskit.GlobalItemRecommender;
import org.grouplens.lenskit.GlobalItemScorer;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.RecommenderBuildException;
import org.grouplens.lenskit.core.LenskitConfiguration;
import org.grouplens.lenskit.core.LenskitRecommender;
import org.grouplens.lenskit.data.dao.EventDAO;
import org.grouplens.lenskit.data.dao.ItemDAO;
import org.grouplens.lenskit.data.dao.UserDAO;
import org.grouplens.lenskit.knn.NeighborhoodSize;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.vectors.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Command-line driver for the item-item collaborative filtering assignment.
 *
 * <p>Supported invocations:
 * <ul>
 * <li>{@code --all}: score every item known to the title DAO for every user known to the
 *     user DAO.</li>
 * <li>{@code --basket ITEM...}: print the top items the global item recommender returns for
 *     the given basket of item IDs.</li>
 * <li>{@code USER:ITEM ...}: score only the listed user/item pairs.</li>
 * </ul>
 *
 * @author <a href="http://www.grouplens.org">GroupLens Research</a>
 */
public class IIMain {
    private static final Logger logger = LoggerFactory.getLogger("ii-assignment");

    /** Matches a single {@code user:item} argument; compiled once rather than per call. */
    private static final Pattern USER_ITEM_PATTERN = Pattern.compile("(\\d+):(\\d+)");

    /** Number of items printed for a {@code --basket} request. */
    private static final int BASKET_RECOMMENDATION_COUNT = 5;

    /**
     * Main entry point to the program.
     *
     * <p>Exits with status 2 if the recommender cannot be built or a component this driver
     * needs (item title DAO, item scorer, user DAO) is not available from it.
     *
     * @param args Either {@code --all}, {@code --basket} followed by item IDs, or the
     *             <tt>user:item</tt> pairs to score.
     */
    public static void main(String[] args) {
        Map<Long, Set<Long>> toScore = null;
        Set<Long> basket = null;
        if (args.length == 1 && args[0].equals("--all")) {
            logger.info("scoring for all users");
        } else if (args.length >= 1 && args[0].equals("--basket")) {
            basket = new HashSet<Long>();
            for (int i = 1; i < args.length; i++) {
                basket.add(Long.parseLong(args[i]));
            }
        } else {
            toScore = parseArgs(args);
        }

        LenskitConfiguration config = configureRecommender();
        LenskitRecommender rec;
        try {
            rec = LenskitRecommender.build(config);
        } catch (RecommenderBuildException e) {
            logger.error("error building recommender", e);
            System.exit(2);
            throw new AssertionError(); // to de-confuse unreachable code detection
        }

        // Get the item title DAO, so we can look up movie titles
        ItemTitleDAO titleDAO = rec.get(ItemTitleDAO.class);
        if (titleDAO == null) {
            // Every output path below prints titles, so fail cleanly instead of with an NPE.
            logger.error("no item title DAO");
            System.exit(2);
            throw new AssertionError(); // to de-confuse unreachable code detection
        }

        if (basket != null) {
            printBasketRecommendations(rec, titleDAO, basket);
            return;
        }

        // Get the item scorer and go!
        ItemScorer scorer = rec.getItemScorer();
        if (scorer == null) {
            // An explicit check rather than `assert`, which is skipped unless the JVM runs with -ea.
            logger.error("no item scorer");
            System.exit(2);
            throw new AssertionError(); // to de-confuse unreachable code detection
        }

        if (toScore == null) {
            toScore = allUserItemPairs(rec, titleDAO);
        }
        scoreAndPrint(scorer, titleDAO, toScore);
    }

    /**
     * Prints the items the global item recommender returns for a basket, one
     * {@code item,score,title} line per item.
     *
     * @param rec      The built recommender.
     * @param titleDAO DAO used to look up item titles.
     * @param basket   The item IDs forming the query basket.
     */
    private static void printBasketRecommendations(LenskitRecommender rec,
                                                   ItemTitleDAO titleDAO,
                                                   Set<Long> basket) {
        GlobalItemRecommender grec = rec.getGlobalItemRecommender();
        logger.info("printing items similar to {}", basket);
        List<ScoredId> items = grec.globalRecommend(basket, BASKET_RECOMMENDATION_COUNT);
        for (ScoredId item : items) {
            System.out.format(Locale.ROOT, "%d,%.4f,%s\n", item.getId(), item.getScore(),
                              titleDAO.getItemTitle(item.getId()));
        }
    }

    /**
     * Builds a request that scores every item known to the title DAO for every user known to
     * the user DAO. Exits with status 2 if the recommender has no user DAO.
     *
     * @param rec      The built recommender.
     * @param titleDAO DAO providing the set of item IDs.
     * @return A map from each user ID to the (shared) set of all item IDs.
     */
    private static Map<Long, Set<Long>> allUserItemPairs(LenskitRecommender rec,
                                                         ItemTitleDAO titleDAO) {
        logger.debug("loading user/item sets");
        UserDAO userDAO = rec.get(UserDAO.class);
        if (userDAO == null) {
            logger.error("no user DAO");
            System.exit(2);
            throw new AssertionError(); // to de-confuse unreachable code detection
        }
        // The item set is identical for every user, so fetch it once and share it;
        // nothing in this driver modifies the per-user item sets.
        Set<Long> allItems = titleDAO.getItemIds();
        Map<Long, Set<Long>> toScore = Maps.newHashMap();
        for (Long user : userDAO.getUserIds()) {
            toScore.put(user, allItems);
        }
        return toScore;
    }

    /**
     * Scores the requested items and prints one {@code user,item,score,title} line per pair.
     * Items the scorer returns no score for are printed with score {@code NA}.
     *
     * @param scorer   The item scorer.
     * @param titleDAO DAO used to look up item titles.
     * @param toScore  Map from user ID to the item IDs to score for that user.
     */
    private static void scoreAndPrint(ItemScorer scorer, ItemTitleDAO titleDAO,
                                      Map<Long, Set<Long>> toScore) {
        logger.info("scoring for {} users", toScore.size());
        for (Map.Entry<Long, Set<Long>> scoreRequest : toScore.entrySet()) {
            long user = scoreRequest.getKey();
            Set<Long> items = scoreRequest.getValue();
            logger.info("scoring {} items for user {}", items.size(), user);
            // We call the score method that takes a set of items.
            // AbstractItemScorer delegates this method to the one you are supposed to implement.
            SparseVector scores = scorer.score(user, items);
            for (long item : items) {
                String score;
                if (scores.containsKey(item)) {
                    score = String.format(Locale.ROOT, "%.4f", scores.get(item));
                } else {
                    score = "NA";
                }
                String title = titleDAO.getItemTitle(item);
                // Locale.ROOT keeps the numeric fields locale-independent, matching the score
                // formatting above and the basket output.
                System.out.format(Locale.ROOT, "%d,%d,%s,%s\n", user, item, score, title);
            }
        }
    }

    /**
     * Parse the command line arguments.
     *
     * <p>Each argument must have the form {@code user:item} (both non-negative integers).
     * Arguments that do not match, or whose numbers do not fit in a {@code long}, are logged
     * and skipped.
     *
     * @param args The command line arguments.
     * @return A map of users to the sets of items to score for them (empty if no argument
     *         parsed).
     */
    private static Map<Long, Set<Long>> parseArgs(String[] args) {
        logger.info("parsing {} command line arguments", args.length);
        Map<Long, Set<Long>> map = Maps.newHashMap();
        for (String arg : args) {
            logger.debug("parsing argument: {}", arg);
            Matcher m = USER_ITEM_PATTERN.matcher(arg);
            if (!m.matches()) {
                logger.error("unparseable command line argument {}", arg);
                continue;
            }
            long uid;
            long iid;
            try {
                uid = Long.parseLong(m.group(1));
                iid = Long.parseLong(m.group(2));
            } catch (NumberFormatException e) {
                // \d+ also matches digit strings too large for a long; treat them as unparseable.
                logger.error("unparseable command line argument {}", arg);
                continue;
            }
            Set<Long> items = map.get(uid);
            if (items == null) {
                items = Sets.<Long>newHashSet();
                map.put(uid, items);
            }
            items.add(iid);
        }
        return map;
    }

    /**
     * Create the LensKit recommender configuration.
     *
     * <p>Reads ratings, item titles and users from CSV files under {@code data/}, scores with
     * {@code SimpleItemItemScorer} / {@code SimpleGlobalItemScorer}, and uses a neighborhood
     * size of 20.
     *
     * @return The LensKit recommender configuration.
     */
    @SuppressWarnings("unchecked") // LensKit configuration API generates some unchecked warnings
    private static LenskitConfiguration configureRecommender() {
        LenskitConfiguration config = new LenskitConfiguration();
        // configure the rating data source
        config.bind(EventDAO.class)
              .to(MOOCRatingDAO.class);
        config.set(RatingFile.class)
              .to(new File("data/ratings.csv"));

        // use custom item and user DAOs
        // our item DAO has title information
        config.bind(ItemDAO.class)
              .to(MOOCItemDAO.class);
        // and title file
        config.set(TitleFile.class)
              .to(new File("data/movie-titles.csv"));

        // our user DAO can look up by user name
        config.bind(UserDAO.class)
              .to(MOOCUserDAO.class);
        // main() looks the UserDAO up directly via rec.get(UserDAO.class) for --all;
        // registering it as a root presumably ensures it is part of the built graph.
        config.addRoot(UserDAO.class);
        config.set(UserFile.class)
              .to(new File("data/users.csv"));

        // use the item-item scorer you will implement to score items
        config.bind(ItemScorer.class)
              .to(SimpleItemItemScorer.class);
        config.bind(GlobalItemScorer.class)
              .to(SimpleGlobalItemScorer.class);
        config.set(NeighborhoodSize.class)
              .to(20);
        return config;
    }
}