/* * Copyright (C) 2012 Sebastian Schelter <sebastian.schelter [at] tu-berlin.de> * * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the * specific language governing permissions and limitations under the License. */ package de.tuberlin.dima.recsys.ssnmm.interactioncut; import org.apache.mahout.cf.taste.common.TasteException; import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; import org.apache.mahout.cf.taste.impl.common.RunningAverage; import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender; import org.apache.mahout.cf.taste.model.DataModel; import org.apache.mahout.cf.taste.model.Preference; import org.apache.mahout.cf.taste.model.PreferenceArray; import org.apache.mahout.cf.taste.similarity.ItemSimilarity; import org.apache.mahout.math.map.OpenLongDoubleHashMap; /** * Itembased recommender that uses weighted sum estimation enhanced by baseline estimates */ public class BiasedItemBasedRecommender extends GenericItemBasedRecommender { private final int k; private final double mu; private final OpenLongDoubleHashMap itemBiases; private final OpenLongDoubleHashMap userBiases; private final ItemSimilarity similarity; public BiasedItemBasedRecommender(DataModel dataModel, ItemSimilarity similarity, int k, double lambda2, double lambda3) throws TasteException { super(dataModel, similarity); this.k = k; this.similarity = similarity; RunningAverage averageRating = new FullRunningAverage(); LongPrimitiveIterator itemIDs = getDataModel().getItemIDs(); while (itemIDs.hasNext()) { for (Preference pref : getDataModel().getPreferencesForItem(itemIDs.next())) { averageRating.addDatum(pref.getValue()); } } mu = averageRating.getAverage(); itemBiases = new OpenLongDoubleHashMap(getDataModel().getNumItems()); userBiases = new OpenLongDoubleHashMap(getDataModel().getNumUsers()); itemIDs = getDataModel().getItemIDs(); while (itemIDs.hasNext()) { long itemID = itemIDs.nextLong(); PreferenceArray preferences = getDataModel().getPreferencesForItem(itemID); double sum = 0; for (Preference pref : preferences) { sum += pref.getValue() - mu; } double bi = sum / (lambda2 + preferences.length()); itemBiases.put(itemID, bi); } LongPrimitiveIterator userIDs = getDataModel().getUserIDs(); while (userIDs.hasNext()) { long userID = userIDs.nextLong(); PreferenceArray preferences = getDataModel().getPreferencesFromUser(userID); double sum = 0; for (Preference pref : preferences) { sum += pref.getValue() - mu - itemBiases.get(pref.getItemID()); } double bu = sum / (lambda3 + preferences.length()); userBiases.put(userID, bu); } } @Override public float estimatePreference(long userID, long itemID) throws TasteException { PreferenceArray preferencesFromUser = getDataModel().getPreferencesFromUser(userID); Float actualPref = getPreferenceForItem(preferencesFromUser, itemID); if (actualPref != null) { return actualPref; } return doEstimatePreference(userID, preferencesFromUser, itemID); } private static Float getPreferenceForItem(PreferenceArray preferencesFromUser, long itemID) { int size = preferencesFromUser.length(); for (int i = 0; i < size; i++) { if (preferencesFromUser.getItemID(i) == itemID) { return preferencesFromUser.getValue(i); } } return null; } protected double baselineEstimate(long userID, long itemID) throws TasteException { return mu + userBiases.get(userID) + itemBiases.get(itemID); } @Override protected float doEstimatePreference(long userID, PreferenceArray preferencesFromUser, long itemID) throws TasteException { double preference = 0.0; double totalSimilarity = 0.0; int count = 0; long[] userIDs = preferencesFromUser.getIDs(); float[] ratings = new float[userIDs.length]; long[] itemIDs = new long[userIDs.length]; double[] similarities = similarity.itemSimilarities(itemID, userIDs); for (int n = 0; n < preferencesFromUser.length(); n++) { ratings[n] = preferencesFromUser.get(n).getValue(); itemIDs[n] = preferencesFromUser.get(n).getItemID(); } quickSort(similarities, ratings, itemIDs, 0, (similarities.length - 1)); for (int i = 0; i < Math.min(k, similarities.length); i++) { double theSimilarity = similarities[i]; if (!Double.isNaN(theSimilarity)) { preference += theSimilarity * (ratings[i] - baselineEstimate(userID, itemIDs[i])); totalSimilarity += theSimilarity; count++; } } if (count <= 1) { return Float.NaN; } float estimate = (float) (baselineEstimate(userID, itemID) + (preference / totalSimilarity)); return estimate; } protected void quickSort(double[] similarities, float[] values, long[] otherValues, int start, int end) { if (start < end) { double pivot = similarities[end]; float pivotValue = values[end]; int i = start; int j = end; while (i != j) { if (similarities[i] > pivot) { i = i + 1; } else { similarities[j] = similarities[i]; values[j] = values[i]; otherValues[j] = otherValues[i]; similarities[i] = similarities[j - 1]; values[i] = values[j - 1]; otherValues[i] = otherValues[j - 1]; j = j - 1; } } similarities[j] = pivot; values[j] = pivotValue ; quickSort(similarities, values, otherValues, start, j - 1); quickSort(similarities, values, otherValues, j + 1, end); } } }