package happy.research.cf; import happy.coding.io.FileIO; import happy.coding.io.FileIO.MapWriter; import happy.coding.io.KeyValPair; import happy.coding.io.Lists; import happy.coding.io.Strings; import happy.coding.math.Sims; import java.io.File; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import org.junit.Test; public class RecSysCourse_mt extends DefaultCF_mt { private final static String dir = "D:\\Dropbox\\PhD\\My Work\\Algorithms\\@Machine Learning\\RecSys\\Assignments\\"; public RecSysCourse_mt() { methodId = "RecSys-Course CF"; } @Override protected Performance runMultiThreads() throws Exception { A4(); return null; } protected void A4() throws Exception { String dirPath = dir + "A4" + File.separator; Map<String, Map<String, Double>> corrs = new HashMap<>(); for (String a : userRatingsMap.keySet()) { Map<String, Double> ucs = new HashMap<>(); Map<String, Rating> asRatings = userRatingsMap.get(a); for (String b : userRatingsMap.keySet()) { if (a == b) { ucs.put(b, 1.0); } else { Map<String, Rating> bsRatings = userRatingsMap.get(b); List<Double> as = new ArrayList<>(); List<Double> bs = new ArrayList<>(); for (String ai : asRatings.keySet()) { if (bsRatings.containsKey(ai)) { as.add(asRatings.get(ai).getRating()); bs.add(bsRatings.get(ai).getRating()); } } if (as.size() >= 2) { double corr = Sims.pcc(as, bs); if (!Double.isNaN(corr)) ucs.put(b, corr); } } } corrs.put(a, ucs); } String testA1 = "1648"; String testA2 = "5136"; String testB1 = "918"; String testB2 = "2824"; float ca = corrs.get(testA1).get(testA2).floatValue(); float cb = corrs.get(testB1).get(testB2).floatValue(); assert ca == 0.40298; assert cb == -0.31706; String[] users = {/* examples */"3712", /* tests */"3867", "860" }; String file1 = dirPath + "results-part-1.txt"; String file2 = dirPath + "results-part-2.txt"; FileIO.deleteFile(file1); FileIO.deleteFile(file2); for (int k = 0; k < users.length; k++) { String user = users[k]; Map<String, Double> nns = corrs.get(user); List<KeyValPair<String>> sorted = Lists.sortMap(nns, true); int knn = 0; List<KeyValPair<String>> found = new ArrayList<>(); for (KeyValPair<String> sp : sorted) { String u = sp.getKey(); if (u.equals(user)) continue; found.add(sp); if (k == 0) System.out.println(sp.getKey() + ":" + sp.getValue()); knn++; if (knn >= 5) break; } // do predictions: part I // Map<String, Rating> usRatings = userRatingsMap.get(user); Map<String, Double> itemPreds = new HashMap<>(); for (String item : itemRatingsMap.keySet()) { // if (usRatings.containsKey(item)) continue; double sum = 0; double val = 0; for (KeyValPair<String> sp : found) { String v = sp.getKey(); double c = sp.getValue(); if (userRatingsMap.get(v).containsKey(item)) { val += userRatingsMap.get(v).get(item).getRating() * c; sum += c; } } if (sum > 0) { double pred = val / sum; itemPreds.put(item, pred); } } List<KeyValPair<String>> recs = Lists.sortMap(itemPreds, true); List<String> lines = new ArrayList<>(); int recNum = 6; int cnt = 0; for (int m = 0; m < recs.size(); m++) { KeyValPair<String> rec = recs.get(m); String val = Strings.toString(rec.getValue(), 3); String line = rec.getKey() + " " + val; cnt++; if (cnt > recNum) break; lines.add(line); } String content = Strings.toString(lines); if (k == 0) System.out.println(content); else FileIO.writeString(file1, content, true); // do prediction - part 2 itemPreds.clear(); Map<String, Rating> usRatings = userRatingsMap.get(user); double mu = RatingUtils.mean(usRatings, null); for (String item : itemRatingsMap.keySet()) { // if (usRatings.containsKey(item)) continue; double sum = 0; double val = 0; for (KeyValPair<String> sp : found) { String v = sp.getKey(); double c = sp.getValue(); Map<String, Rating> vsRatings = userRatingsMap.get(v); if (vsRatings.containsKey(item)) { double mv = RatingUtils.mean(vsRatings, null); double rate = vsRatings.get(item).getRating(); val += (rate - mv) * c; sum += c; } } if (sum > 0) { double pred = mu + val / sum; itemPreds.put(item, pred); } } recs = Lists.sortMap(itemPreds, true); lines.clear(); cnt = 0; for (int m = 0; m < recs.size(); m++) { KeyValPair<String> rec = recs.get(m); String val = Strings.toString(rec.getValue(), 3); String line = rec.getKey() + " " + val; cnt++; if (cnt > recNum) break; lines.add(line); } content = Strings.toString(lines); if (k == 0) System.out.println(content); else FileIO.writeString(file2, content, true); } } @Test public void convertData() throws Exception { String dirPath = dir + "A4" + File.separator; String source = dirPath + "recsys-data-sample-rating-matrix.csv"; List<String> content = FileIO.readAsList(source); List<String> users = new ArrayList<>(); Map<String, Map<String, Double>> data = new HashMap<>(); String head = content.get(0); String[] vals = head.split(","); for (String user : vals) { if (!user.equals("\"\"")) { user = user.replace("\"", ""); users.add(user); data.put(user, new HashMap<String, Double>()); } } List<String> items = new ArrayList<>(); for (int i = 1; i < content.size(); i++) { String line = content.get(i); vals = line.split(","); String item = vals[0].substring(1, vals[0].indexOf(":")); items.add(item); for (int j = 1; j < vals.length; j++) { String rate = vals[j]; if (!rate.equals("")) { String user = users.get(j - 1); Map<String, Double> irs = data.get(user); irs.put(item, Double.parseDouble(rate)); data.put(user, irs); } } } String filePath = dirPath + "ratings.txt"; FileIO.deleteFile(filePath); for (final String user : users) { Map<String, Double> irs = data.get(user); FileIO.writeMap(filePath, irs, new MapWriter<String, Double>() { @Override public String processEntry(String key, Double val) { return user + " " + key + " " + val.floatValue(); } }, true); } } }