/*
* Seldon -- open source prediction engine
* =======================================
*
* Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
*
* ********************************************************************************************
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* ********************************************************************************************
*/
package io.seldon.mf;
import io.seldon.clustering.recommender.ItemRecommendationAlgorithm;
import io.seldon.clustering.recommender.ItemRecommendationResultSet;
import io.seldon.clustering.recommender.ItemRecommendationResultSet.ItemRecommendationResult;
import io.seldon.clustering.recommender.RecommendationContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.RealVector;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import com.google.common.collect.Ordering;
@Component
public class RecentMfRecommender implements ItemRecommendationAlgorithm {
private static Logger logger = Logger.getLogger(RecentMfRecommender.class.getName());
private static final String name = RecentMfRecommender.class.getSimpleName();
private static final String RECENT_ACTIONS_PROPERTY_NAME = "io.seldon.algorithm.general.numrecentactionstouse";
private final MfFeaturesManager store;
@Autowired
public RecentMfRecommender(MfFeaturesManager store){
this.store = store;
}
@Override
public ItemRecommendationResultSet recommend(String client, Long user, Set<Integer> dimensions,
int maxRecsCount, RecommendationContext ctxt, List<Long> recentItemInteractions) {
RecommendationContext.OptionsHolder opts = ctxt.getOptsHolder();
int numRecentActionsToUse = opts.getIntegerOption(RECENT_ACTIONS_PROPERTY_NAME);
MfFeaturesManager.ClientMfFeaturesStore clientStore = this.store.getClientStore(client, ctxt.getOptsHolder());
if(clientStore==null) {
logger.debug("Couldn't find a matrix factorization store for this client");
return new ItemRecommendationResultSet(Collections.<ItemRecommendationResult>emptyList(), name);
}
List<Long> itemsToScore;
if(recentItemInteractions.size() > numRecentActionsToUse)
{
if (logger.isDebugEnabled())
logger.debug("Limiting recent items for score to size "+numRecentActionsToUse+" from present "+recentItemInteractions.size());
itemsToScore = recentItemInteractions.subList(0, numRecentActionsToUse);
}
else
itemsToScore = new ArrayList<>(recentItemInteractions);
if (logger.isDebugEnabled())
logger.debug("Recent items of size "+itemsToScore.size()+" -> "+itemsToScore.toString());
double[] userVector;
if (clientStore.productFeaturesInverse != null)
{
//fold in user data from their recent history of item interactions
logger.debug("Creating user vector by folding in features");
userVector = foldInUser(itemsToScore, clientStore.productFeaturesInverse, clientStore.idMap);
}
else
{
logger.debug("Creating user vector by averaging features");
userVector = createAvgProductVector(itemsToScore, clientStore.productFeatures);
}
Set<ItemRecommendationResult> recs = new HashSet<>();
if(ctxt.getMode()== RecommendationContext.MODE.INCLUSION){
// special case for INCLUSION as it's easier on the cpu.
for (Long item : ctxt.getContextItems()){
if (!recentItemInteractions.contains(item))
{
float[] features = clientStore.productFeatures.get(item);
if(features!=null)
recs.add(new ItemRecommendationResult(item, dot(features,userVector)));
}
}
} else {
for (Map.Entry<Long, float[]> productFeatures : clientStore.productFeatures.entrySet()) {
Long item = productFeatures.getKey().longValue();
if (!recentItemInteractions.contains(item))
{
recs.add(new ItemRecommendationResult(item,dot(productFeatures.getValue(),userVector)));
}
}
}
List<ItemRecommendationResult> recsList = Ordering.natural().greatestOf(recs, maxRecsCount);
if (logger.isDebugEnabled())
logger.debug("Created "+recsList.size() + " recs");
return new ItemRecommendationResultSet(recsList, name);
}
public double[] createAvgProductVector(List<Long> recentitemInteractions,Map<Long,float[]> productFeatures)
{
int numLatentFactors = productFeatures.values().iterator().next().length;
double[] userFeatures = new double[numLatentFactors];
for (Long item : recentitemInteractions)
{
float[] productFactors = productFeatures.get(item);
if (productFactors != null)
{
for (int feature = 0; feature < numLatentFactors; feature++)
{
userFeatures[feature] += productFactors[feature];
}
}
}
RealVector userFeaturesAsVector = new ArrayRealVector(userFeatures);
RealVector normalised = userFeaturesAsVector.mapDivide(userFeaturesAsVector.getL1Norm());
return normalised.getData();
}
/**
* http://www.slideshare.net/fullscreen/srowen/matrix-factorization/16
* @param recentitemInteractions
* @param productFeaturesInverse
* @param idMap
* @return
*/
public double[] foldInUser(List<Long> recentitemInteractions,double[][] productFeaturesInverse,Map<Long,Integer> idMap) {
int numLatentFactors = productFeaturesInverse[0].length;
double[] userFeatures = new double[numLatentFactors];
for (Long item : recentitemInteractions)
{
Integer id = idMap.get(item);
if (id != null)
{
for (int feature = 0; feature < numLatentFactors; feature++)
{
userFeatures[feature] += productFeaturesInverse[id][feature];
}
}
}
return userFeatures;
}
private static float dot(float[] vec1, double[] vec2){
float sum = 0;
for (int i = 0; i < vec1.length; i++){
sum += vec1[i] * vec2[i];
}
return sum;
}
@Override
public String name() {
return name;
}
}