/*
 * Seldon -- open source prediction engine
 * =======================================
 *
 * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
 *
 * ********************************************************************************************
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * ********************************************************************************************
 */

package io.seldon.topics;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

import io.seldon.mf.PerClientExternalLocationListener;
import io.seldon.recommendation.model.ModelManager;
import io.seldon.resources.external.ExternalResourceStreamer;
import io.seldon.resources.external.NewResourceNotifier;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * Loads per-client topic model features (user/topic weights and tag/topic weights) from
 * two CSV resources, {@code users.csv} and {@code topics.csv}, found under a model location,
 * and exposes them as a {@link TopicFeaturesStore}.
 */
@Component
public class TopicFeaturesManager extends ModelManager<TopicFeaturesManager.TopicFeaturesStore> {

    private static final Logger logger = Logger.getLogger(TopicFeaturesManager.class.getName());

    private final ExternalResourceStreamer featuresFileHandler;

    /** Location pattern handed to the {@link ModelManager} super constructor. */
    public static final String TOPIC_NEW_LOC_PATTERN = "topics";

    /**
     * @param featuresFileHandler used to open the {@code users.csv} / {@code topics.csv} streams
     * @param notifier            passed to the super constructor together with
     *                            {@link #TOPIC_NEW_LOC_PATTERN}
     */
    @Autowired
    public TopicFeaturesManager(ExternalResourceStreamer featuresFileHandler, NewResourceNotifier notifier) {
        super(notifier, Collections.singleton(TOPIC_NEW_LOC_PATTERN));
        this.featuresFileHandler = featuresFileHandler;
    }

    /**
     * Parses lines of the form {@code topic,keyword,weight}.
     *
     * @param reader source of the CSV lines; not closed by this method
     * @return map of keyword to (topic id to weight); a repeated (keyword, topic) pair keeps
     *         the last weight read
     * @throws IOException if reading from {@code reader} fails
     * @throws NumberFormatException if a topic id or weight cannot be parsed
     * @throws ArrayIndexOutOfBoundsException if a line has fewer than three comma-separated fields
     */
    private Map<String, Map<Integer, Float>> readTopicFeatures(BufferedReader reader) throws IOException {
        Map<String, Map<Integer, Float>> keywordToTopics = new HashMap<>();
        String line;
        while ((line = reader.readLine()) != null) {
            String[] fields = line.split(",");
            Integer topic = Integer.parseInt(fields[0]);
            String keyword = fields[1];
            Float weight = Float.parseFloat(fields[2]);

            Map<Integer, Float> topicToWeight = keywordToTopics.get(keyword);
            if (topicToWeight == null) {
                // First time this keyword is seen: register its map once.
                topicToWeight = new HashMap<>();
                keywordToTopics.put(keyword, topicToWeight);
            }
            topicToWeight.put(topic, weight);
        }
        return keywordToTopics;
    }

    /**
     * Parses lines of the form {@code userId,topic:weight,topic:weight,...}.
     *
     * @param reader source of the CSV lines; not closed by this method
     * @return map of user id to (topic id to weight); if a user id appears on several lines
     *         the last line replaces the earlier ones
     * @throws IOException if reading from {@code reader} fails
     * @throws NumberFormatException if a user id, topic id or weight cannot be parsed
     * @throws ArrayIndexOutOfBoundsException if a {@code topic:weight} field has no {@code ':'}
     */
    private Map<Long, Map<Integer, Float>> readUserFeatures(BufferedReader reader) throws IOException {
        Map<Long, Map<Integer, Float>> userToTopics = new HashMap<>();
        String line;
        while ((line = reader.readLine()) != null) {
            String[] userAndTopics = line.split(",");
            Long user = Long.parseLong(userAndTopics[0]);

            Map<Integer, Float> topicMap = new HashMap<>();
            for (int i = 1; i < userAndTopics.length; i++) {
                String[] topicAndWeight = userAndTopics[i].split(":");
                topicMap.put(Integer.parseInt(topicAndWeight[0]), Float.parseFloat(topicAndWeight[1]));
            }
            userToTopics.put(user, topicMap);
        }
        return userToTopics;
    }

    /**
     * Opens a buffered character reader over the resource at {@code path}.
     * NOTE(review): the reader uses the platform default charset, as the original code did;
     * confirm whether the feature files are guaranteed to be in that encoding.
     */
    private BufferedReader openReader(String path) throws IOException {
        return new BufferedReader(new InputStreamReader(featuresFileHandler.getResourceStream(path)));
    }

    /**
     * Reads {@code <location>/users.csv} and {@code <location>/topics.csv} and builds a store
     * from them.
     *
     * @param location base location of the two CSV resources
     * @param client   client name, used in log messages only
     * @return the loaded store, or {@code null} if an {@link IOException} occurred (it is logged)
     */
    @Override
    protected TopicFeaturesStore loadModel(String location, String client) {
        logger.info("Reloading topic features for client: " + client);
        try {
            // try-with-resources guarantees each reader is closed even when parsing throws
            // (including unchecked exceptions from malformed lines), which previously leaked
            // the underlying streams.
            Map<Long, Map<Integer, Float>> userFeatures;
            try (BufferedReader userFeaturesReader = openReader(location + "/users.csv")) {
                userFeatures = readUserFeatures(userFeaturesReader);
            }
            logger.info("Loaded user features for client " + client + " with map of size " + userFeatures.size());

            Map<String, Map<Integer, Float>> topicFeatures;
            try (BufferedReader topicFeaturesReader = openReader(location + "/topics.csv")) {
                topicFeatures = readTopicFeatures(topicFeaturesReader);
            }
            logger.info("Loaded topic features for client " + client + " with map of size " + topicFeatures.size());

            logger.info("finished load of topic features for client " + client);
            return new TopicFeaturesStore(userFeatures, topicFeatures);
        } catch (IOException e) {
            // FileNotFoundException is an IOException and was logged with the same message.
            logger.error("Couldn't reloadFeatures for client " + client, e);
        }
        return null;
    }

    /** A topic weight vector together with the time (epoch millis) it was created. */
    public static class TopicWeights {
        float[] weights;
        long created;

        public TopicWeights(float[] weights, long created) {
            super();
            this.weights = weights;
            this.created = created;
        }
    }

    /**
     * Holds the user/topic and tag/topic weights of one loaded model plus a cache of
     * item topic vectors predicted from item tags.
     */
    public static class TopicFeaturesStore {
        int numTopics;
        Map<Long, Map<Integer, Float>> userTopicWeights;
        Map<String, Map<Integer, Float>> tagTopicWeights;
        // Cache of predicted item vectors, keyed by the id passed to getTopicWeights.
        Map<Long, TopicWeights> itemTopicWeights = new ConcurrentHashMap<>();

        /**
         * @param userTopicWeights user id to (topic id to weight)
         * @param tagTopicWeights  tag to (topic id to weight)
         */
        public TopicFeaturesStore(
                Map<Long, Map<Integer, Float>> userTopicWeights,
                Map<String, Map<Integer, Float>> tagTopicWeights) {
            super();
            this.userTopicWeights = userTopicWeights;
            this.tagTopicWeights = tagTopicWeights;
            // Topic ids are used as array indices, so the vector length is the largest id + 1.
            int maxTopic = Math.max(
                    maxTopicId(tagTopicWeights.values()),
                    maxTopicId(userTopicWeights.values()));
            this.numTopics = maxTopic + 1;
        }

        /** Returns the largest topic id used as a key in any of the maps, or 0 if there is none. */
        private static int maxTopicId(Collection<Map<Integer, Float>> topicMaps) {
            int maxTopic = 0;
            for (Map<Integer, Float> topicMap : topicMaps) {
                for (Integer topic : topicMap.keySet()) {
                    if (topic > maxTopic) {
                        maxTopic = topic;
                    }
                }
            }
            return maxTopic;
        }

        /**
         * @param userId user to look up
         * @return a new array of length {@link #getNumTopics()} holding the user's weight at each
         *         topic index (0 where the user has none), or {@code null} if the user is unknown
         */
        public float[] getUserWeightVector(Long userId) {
            Map<Integer, Float> topicWeights = userTopicWeights.get(userId);
            if (topicWeights == null) {
                return null;
            }
            float[] v = new float[numTopics];
            for (Map.Entry<Integer, Float> e : topicWeights.entrySet()) {
                v[e.getKey()] = e.getValue();
            }
            return v;
        }

        public int getNumTopics() {
            return numTopics;
        }

        /**
         * Returns the cached topic vector for {@code key}, computing it from {@code tags} and
         * caching it on first use. The lookup and the insert are separate steps, so concurrent
         * callers for the same key may each compute a vector; the last one stored stays cached.
         * The returned array is the cached instance, not a copy.
         *
         * @param key  cache key (presumably an item id -- confirm against callers)
         * @param tags tags used to predict the vector when it is not cached yet
         * @return topic weight vector of length {@link #getNumTopics()}
         */
        public float[] getTopicWeights(Long key, List<String> tags) {
            TopicWeights cached = itemTopicWeights.get(key);
            if (cached != null) {
                return cached.weights;
            }
            float[] weights = predictTopicWeights(tags);
            itemTopicWeights.put(key, new TopicWeights(weights, System.currentTimeMillis()));
            return weights;
        }

        //Simplistic average of topic weights in tags
        // Need to change to proper online LDA approx
        /**
         * Sums, per topic, the weights of every tag that has an entry in the tag/topic map and
         * divides each total by the sum of all weights added, when that sum is positive.
         *
         * @param tags tags to score; tags without known topic weights are ignored
         * @return vector of length {@code numTopics}; all zeros if no tag contributed
         */
        private float[] predictTopicWeights(List<String> tags) {
            float[] score = new float[numTopics];
            float sum = 0.0f;
            for (String tag : tags) {
                Map<Integer, Float> topicWeightMap = tagTopicWeights.get(tag);
                if (topicWeightMap == null) {
                    continue; // no topic weights known for this tag
                }
                for (Map.Entry<Integer, Float> topicWeight : topicWeightMap.entrySet()) {
                    score[topicWeight.getKey()] += topicWeight.getValue();
                    sum = sum + topicWeight.getValue();
                }
            }
            if (sum > 0) {
                for (int i = 0; i < score.length; i++) {
                    score[i] = score[i] / sum;
                }
            }
            return score;
        }
    }
}