/* * Sifarish: Recommendation Engine * Author: Pranab Ghosh * * Licensed under the Apache License, Version 2.0 (the "License"); you * may not use this file except in compliance with the License. You may * obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. See the License for the specific language governing * permissions and limitations under the License. */ package org.sifarish.realtime; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.chombo.storm.Cache; import org.chombo.util.ConfigUtility; import org.chombo.util.Pair; import org.codehaus.jackson.JsonParseException; import org.codehaus.jackson.map.JsonMappingException; import org.codehaus.jackson.map.ObjectMapper; import org.sifarish.common.EngagementToPreferenceMapper; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; /** * Gets predicted ratings for all items correlated with items rated by this user * @author pranab * */ public class UserItemRatings { private String userID; private String sessionID; private Map<String, EngagementEvent> engagementEvents = new HashMap<String, EngagementEvent>(); private LoadingCache<String, List<ItemCorrelation>> itemCorrelationCache = null; private int topItemsCount; private String itemCorrelationKey; private EngagementToPreferenceMapper engaementMapper; private String eventExpirePolicy; private long timedExpireWindowSec; private int countExpireLimit; private boolean debugOn; private static final String EVENT_EXPIRE_SESSION = "session"; private static final String EVENT_EXPIRE_TIME = "time"; private static final String EVENT_EXPIRE_COUNT = "count"; private static final Logger LOG = Logger.getLogger(UserItemRatings.class); /** * @param userID * @param sessionID * @throws IOException * @throws JsonMappingException * @throws JsonParseException */ public UserItemRatings(String userID, String sessionID, Cache cache, Map config) throws Exception { super(); this.userID = userID; this.sessionID = sessionID; //this.jedis = jedis; //config int correlationCacheSize = ConfigUtility.getInt(config,"correlation.cache.size"); int correlationCacheExpiryTimeSec = ConfigUtility.getInt(config,"correlation.cache.expiry.time.sec"); topItemsCount = ConfigUtility.getInt(config,"top.items.count"); itemCorrelationKey = ConfigUtility.getString(config, "redis.item.correlation.key"); String eventMappingMetadataKey = ConfigUtility.getString(config, "redis.event.mapping.metadata.key"); eventExpirePolicy = ConfigUtility.getString(config, "event.expire.policy", EVENT_EXPIRE_SESSION); timedExpireWindowSec = ConfigUtility.getLong(config, "timed.expire.window.sec", -1); countExpireLimit = ConfigUtility.getInt(config,"count.expire.limit", -1); //event mapping metadata String eventMappingMetadata = cache.get(eventMappingMetadataKey); ObjectMapper mapper = new ObjectMapper(); engaementMapper = mapper.readValue(eventMappingMetadata, EngagementToPreferenceMapper.class); //log debugOn = ConfigUtility.getBoolean(config,"debug.on", false); //initialize correlation cache if (null == itemCorrelationCache) { Cache corrCache = Cache.createCache(config, itemCorrelationKey); itemCorrelationCache = CacheBuilder.newBuilder() .maximumSize(correlationCacheSize) .expireAfterAccess(correlationCacheExpiryTimeSec, TimeUnit.SECONDS) .build(new ItemCorrelationLoader(corrCache, itemCorrelationKey, debugOn)); } if (debugOn) { LOG.setLevel(Level.INFO); LOG.info("UserItemRatings intialized"); } } /** * @param sessionID * @param itemID * @param event * @param timestamp */ public void addEvent(String sessionID, String itemID, int event, long timestamp) throws Exception { Set<String> affectedItems = new HashSet<String>(); //add event EngagementEvent engageEvents = engagementEvents.get(itemID); if (null == engageEvents) { engageEvents = new EngagementEvent(itemID, itemCorrelationCache, engaementMapper, debugOn); engagementEvents.put(itemID, engageEvents); } engageEvents.addEvent(event, timestamp); affectedItems.add(itemID); if (debugOn) { LOG.info("event added to EngagementEvent"); } //handle event expiry if (eventExpirePolicy.equals(EVENT_EXPIRE_SESSION)) { //expire by session if (this.sessionID != null && !this.sessionID.equals(sessionID)) { for (String item : engagementEvents.keySet()) { if (engagementEvents.get(item).removeAllEvents()) { affectedItems.add(item); } } this.sessionID = sessionID; } } else if (eventExpirePolicy.equals(EVENT_EXPIRE_TIME)) { //expire by time window if (timedExpireWindowSec < 0) { throw new Exception("For event expiry by time window, timed.expire.window needs to be set"); } for (String item : engagementEvents.keySet()) { if (engagementEvents.get(item).removeOldEventsByTime(timedExpireWindowSec)) { affectedItems.add(item); } } } else if (eventExpirePolicy.equals(EVENT_EXPIRE_COUNT)) { //expire by max event list size if (countExpireLimit < 0) { throw new Exception("For event expiry by count, count.expire.limit needs to be set"); } for (String item : engagementEvents.keySet()) { if (engagementEvents.get(item).removeOldEventsByCount(10)) { affectedItems.add(item); } } } if (debugOn) { LOG.info("Handled event expiry, num of affcted items:" + affectedItems.size()); } } /** * gets predicted ratings * @return * @throws Exception */ public List<ItemRating> getPredictedRatings() throws Exception { List<ItemRating> ratings = new ArrayList<ItemRating>(); Map<String, Integer> itemPredictedRatings = new HashMap<String, Integer>(); Map<String, Integer> itemPredictedRatingCounts = new HashMap<String, Integer>(); if (debugOn) { LOG.info("num of items user " + userID + " engaged with:" + engagementEvents.size()); } //all rated items for (String itemID : engagementEvents.keySet()) { //predicted ratings for items correlated to this item if (debugOn) LOG.info("processing item:" +itemID); EngagementEvent engageEvents = engagementEvents.get(itemID); engageEvents.processRating(); List<ItemRating> thisPredictedRatings = engageEvents.getPredictedRatings(); if (debugOn) LOG.info("for item " + itemID + " there are " + thisPredictedRatings.size() + " correlated items with predicted ratings"); //all correlated items for (ItemRating itemRating : thisPredictedRatings) { //aggregate predicted ratings String item = itemRating.getItem(); Integer rating = itemPredictedRatings.get(item); if (null == rating) { itemPredictedRatings.put(item, itemRating.getRating()); itemPredictedRatingCounts.put(item, 1); } else { itemPredictedRatings.put(itemRating.getItem(), rating + itemRating.getRating()); itemPredictedRatingCounts.put(item, itemPredictedRatingCounts.get(item) + 1); } } } //average predicted rating for (String item : itemPredictedRatings.keySet()) { int avRating = itemPredictedRatings.get(item) / itemPredictedRatingCounts.get(item); ratings.add(new ItemRating(item, avRating)); } if (debugOn) { LOG.info("found net " + ratings.size() + " items with predicted rating"); } //sort and collect top n Collections.sort(ratings); if (ratings.size() > topItemsCount) { ratings.subList(topItemsCount, ratings.size()).clear(); if (debugOn) { LOG.info("picked top k items"); } } return ratings; } /** * Cache loader for item correlation * @author pranab * */ private static class ItemCorrelationLoader extends CacheLoader<String, List<ItemCorrelation>> { private Cache corrCache; private boolean debugOn; private static final Logger LOG = Logger.getLogger(ItemCorrelationLoader.class); public ItemCorrelationLoader(Cache corrCache, String itemCorrelationKey, boolean debugOn) { this.corrCache = corrCache; this.debugOn = debugOn; if (debugOn) LOG.setLevel(Level.INFO); } @Override public List<ItemCorrelation> load(String item) throws Exception { List<ItemCorrelation> itemCorrList = new ArrayList<ItemCorrelation>(); String correlation = corrCache.get(item); correlation = correlation.trim(); if (debugOn) LOG.info("item:" + item + " correlation:" +correlation); String[] parts = correlation.split(","); for (String part : parts) { String[] subParts = part.split(":"); ItemCorrelation itemCorr = new ItemCorrelation(subParts[0], Integer.parseInt(subParts[1])); itemCorrList.add(itemCorr); } return itemCorrList; } } /** * Item and rating * @author pranab * */ public static class ItemRating extends Pair<String, Integer> implements Comparable<ItemRating> { public ItemRating(String itemID, int rating) { super(itemID, rating); } public String getItem() { return getLeft(); } public int getRating() { return getRight(); } public void setRating(int rating) { setRight(rating); } @Override public int compareTo(ItemRating that) { return that.getRight().compareTo(this.getRight()); } public ItemRating cloneItemRating() { return new ItemRating(this.left, this.right); } } /** * Item and correlation * @author pranab * */ public static class ItemCorrelation extends Pair<String, Integer> { public ItemCorrelation(String itemID, int correlation) { super(itemID, correlation); } public String getItem() { return getLeft(); } public int getCorrelation() { return getRight(); } } /** * Engaegement events for an item * @author pranab * */ private static class EngagementEvent { private String item; private List<Pair<Integer, Long>> events = new ArrayList<Pair<Integer, Long>>(); private int currentRating = -1; private List<ItemRating> predictedRatings = new ArrayList<ItemRating>(); private LoadingCache<String, List<ItemCorrelation>> itemCorrelationCache; private EngagementToPreferenceMapper engaementMapper; private boolean debugOn; private static final Logger LOG = Logger.getLogger(EngagementEvent.class); /** * @param item * @param itemCorrelationCache */ public EngagementEvent(String item, LoadingCache<String, List<ItemCorrelation>> itemCorrelationCache, EngagementToPreferenceMapper engaementMapper, boolean debugOn) { super(); this.item = item; this.itemCorrelationCache = itemCorrelationCache; this.engaementMapper = engaementMapper; this.debugOn = debugOn; if (debugOn) { LOG.setLevel(Level.INFO); } } /** * */ public boolean removeAllEvents() { events.clear(); predictedRatings.clear(); currentRating = -1; return true; } /** * @param timedExpireWindow */ public boolean removeOldEventsByTime(long timedExpireWindow) { boolean changed = false; long thresholdTime = System.currentTimeMillis() / 1000 - timedExpireWindow; List<Pair<Integer, Long>> filteredEvents = new ArrayList<Pair<Integer, Long>>(); for (Pair<Integer, Long> event : events) { if (event.getRight() >= thresholdTime) { filteredEvents.add(event); } } if (debugOn) { LOG.info("event count:" + events.size() + " event count after filtering : " + filteredEvents.size()); } if (filteredEvents.size() < events.size()) { events = filteredEvents; predictedRatings.clear(); currentRating = -1; changed = true; } return changed; } /** * @param timedExpireWindow */ public boolean removeOldEventsByCount(int maxCount) { boolean changed = false; if (events.size() > maxCount) { events.remove(0); predictedRatings.clear(); currentRating = -1; changed = true; } return changed; } /** * @param event * @param timestamp */ public void addEvent(int event, long timestamp) { events.add(new Pair<Integer, Long>(event, timestamp)); } /** * @throws Exception */ public void processRating() throws Exception { //if predicted rating list is empty and there are events if (predictedRatings.isEmpty() && !events.isEmpty()) { if (debugOn) { LOG.info("going to find predicted ratings"); } //most engaging event and corresponding count int mostEngaingEvent = 1000000; int eventCount = 0; for (Pair<Integer, Long> event : events) { if (event.getLeft() < mostEngaingEvent) { mostEngaingEvent = event.getLeft(); eventCount = 1; } else if (event.getLeft() == mostEngaingEvent) { ++eventCount; } } //estimate implicit rating int rating = engaementMapper.scoreForEvent(mostEngaingEvent, eventCount); if (debugOn) LOG.info("mostEngaingEvent:" + mostEngaingEvent + " eventCount:" + eventCount + " rating:" + rating ); //predicted ratings only if first time or if rating is better that current if (currentRating < 0 || rating > currentRating) { //get correlated items List<ItemCorrelation> itemCorrs = itemCorrelationCache.get(item); if (debugOn) LOG.info("number of correlated items:" + itemCorrs.size()); //predict ratings for (ItemCorrelation itemCorr : itemCorrs) { ItemRating itemRating = new ItemRating(itemCorr.getItem(), itemCorr.getCorrelation() * rating); predictedRatings.add(itemRating); } currentRating = rating; } } } /** * @return */ public List<ItemRating> getPredictedRatings() { return predictedRatings; } } }