/* * Copyright (C) 2012 Sebastian Schelter <sebastian.schelter [at] tu-berlin.de> * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ package de.tuberlin.dima.recsys.ssnmm.ratingprediction; import com.google.common.base.Charsets; import com.google.common.io.Closeables; import com.google.common.io.Files; import de.tuberlin.dima.recsys.ssnmm.Rating; import de.tuberlin.dima.recsys.ssnmm.RatingsIterable; import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; import org.apache.mahout.cf.taste.impl.common.RunningAverage; import java.io.BufferedWriter; import java.io.File; import java.io.IOException; import java.util.Arrays; /** * Java port of the "UserItemBaseline" rating predictor from "mymedialite" https://github.com/zenogantner/MyMediaLite/ */ public class UserItemBaseline { public static void main(String[] args) throws IOException { File trainingFile = new File("/home/ssc/Entwicklung/datasets/yahoo-songs/songs.tsv"); File testFile = new File("/home/ssc/Entwicklung/datasets/yahoo-songs/holdout.tsv"); File outputDir = new File("/home/ssc/Desktop/yahoo/"); int numUsers = 1823179; int numItems = 136736; double mu = 3.157255412010664; int numIterations = 3; UserItemBaseline baseline = new UserItemBaseline(trainingFile, testFile, 0.5, 0, numUsers, numItems, mu); for (int n = 0; n < numIterations; n++) { baseline.train(); } baseline.test(); baseline.persistBiases(outputDir); } private double[] userBiases; private double[] itemBiases; private double globalAverage; private final File ratings; private final File tests; private final double regI; private final double regU; public UserItemBaseline(File ratings, File tests, double regU, double regI, int numUsers, int numItems, double mu) { this.ratings = ratings; this.tests = tests; this.regU = regU; this.regI = regI; globalAverage = mu; userBiases = new double[numUsers]; Arrays.fill(userBiases, 0); itemBiases = new double[numItems]; Arrays.fill(itemBiases, 0); } void test() throws IOException { RunningAverage rmse = new FullRunningAverage(); RunningAverage mae = new FullRunningAverage(); System.out.println("Calculating predictions"); for (Rating rating : new RatingsIterable(tests)) { double error = Math.abs(rating.rating() - baselineEstimate(rating.user(), rating.item())); mae.addDatum(error); rmse.addDatum(error * error); } System.out.println("MAE " + mae.getAverage() + ", RMSE: " + Math.sqrt(rmse.getAverage())); } double baselineEstimate(int user, int item) { return globalAverage + userBiases[user] + itemBiases[item]; } void train() throws IOException { optimizeItemBiases(); optimizeUserBiases(); } void optimizeItemBiases() throws IOException { System.out.println("Optimizing item biases..."); int[] itemRatingsCount = new int[itemBiases.length]; Arrays.fill(itemRatingsCount, 0); int ratingsProcessed = 0; for (Rating rating : new RatingsIterable(ratings)) { itemBiases[rating.item()] += rating.rating() - globalAverage - userBiases[rating.user()]; itemRatingsCount[rating.item()]++; if (++ratingsProcessed % 10000000 == 0) { System.out.println((ratingsProcessed / 1000000) + "M ratings processed"); } } for (int item = 0; item < itemBiases.length; item++) { if (itemRatingsCount[item] != 0) { itemBiases[item] /= regI + itemRatingsCount[item]; } } } void optimizeUserBiases() throws IOException { System.out.println("Optimizing user biases..."); int[] userRatingsCount = new int[userBiases.length]; Arrays.fill(userRatingsCount, 0); int ratingsProcessed = 0; for (Rating rating : new RatingsIterable(ratings)) { userBiases[rating.user()] += rating.rating() - globalAverage - itemBiases[rating.item()]; userRatingsCount[rating.user()]++; if (++ratingsProcessed % 10000000 == 0) { System.out.println((ratingsProcessed / 1000000) + "M ratings processed"); } } for (int user = 0; user < userBiases.length; user++) { if (userRatingsCount[user] != 0) { userBiases[user] /= regU + userRatingsCount[user]; } } } void persistBiases(File dir) throws IOException { persist(new File(dir, "userBiases.tsv"), userBiases); persist(new File(dir, "itemBiases.tsv"), itemBiases); } private void persist(File file, double[] biases) throws IOException { BufferedWriter writer = null; try { writer = Files.newWriter(file, Charsets.UTF_8); for (int index = 0; index < biases.length; index++) { writer.append(String.valueOf(index)); writer.append("\t"); writer.append(String.valueOf(biases[index])); writer.append("\n"); } } finally { Closeables.closeQuietly(writer); } } }