/*
* Seldon -- open source prediction engine
* =======================================
*
* Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
*
* ********************************************************************************************
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* ********************************************************************************************
*/
package io.seldon.mf;
import io.seldon.resources.external.ExternalResourceStreamer;
import io.seldon.resources.external.NewResourceNotifier;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import io.seldon.api.state.ClientAlgorithmStore;
import io.seldon.recommendation.model.ModelManager;
import io.seldon.resources.external.ExternalResourceStreamer;
import io.seldon.resources.external.NewResourceNotifier;
import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.InvalidMatrixException;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import javax.annotation.PostConstruct;
/**
*
* Manages matrix factorization models for recommendations. It loads new
* features files when sent notifications.
*
* @author firemanphil
* Date: 29/09/2014
* Time: 15:35
*/
@Component
public class MfFeaturesManager extends ModelManager<MfFeaturesManager.ClientMfFeaturesStore>{

    private static final Logger logger = Logger.getLogger(MfFeaturesManager.class.getName());

    /** Node pattern handed to the base class; identifies matrix factorization model locations. */
    private static final String MF_NEW_LOC_PATTERN = "mf";

    /** Turns a resource location into a readable stream of feature lines. */
    private final ExternalResourceStreamer featuresFileHandler;

    /**
     * Creates the manager and passes the notifier, together with the single
     * node pattern {@code "mf"}, to the {@link ModelManager} base class.
     *
     * @param featuresFileHandler used to open the user and product feature files
     * @param notifier            forwarded unchanged to the base class
     */
    @Autowired
    public MfFeaturesManager(ExternalResourceStreamer featuresFileHandler,
                             NewResourceNotifier notifier){
        super(notifier, Collections.singleton(MF_NEW_LOC_PATTERN));
        this.featuresFileHandler = featuresFileHandler;
    }

    /**
     * Loads the user and product feature files found under {@code location}
     * and builds the feature store for {@code client}.
     *
     * <p>The files read are {@code <location>/userFeatures.txt.gz} and
     * {@code <location>/productFeatures.txt.gz}. Each reader is closed as soon
     * as its file has been parsed, including when parsing fails.
     *
     * @param location base location of the two feature files
     * @param client   client name, used only in log messages
     * @return the populated store, or {@code null} if a file could not be
     *         opened, read or parsed (the failure is logged)
     */
    public ClientMfFeaturesStore loadModel(String location, String client){
        logger.info("Reloading matrix factorization features for client: "+ client);
        try {
            Map<Long, float[]> userFeatures = loadFeaturesFile(location + "/userFeatures.txt.gz");
            Map<Long, float[]> productFeatures = loadFeaturesFile(location + "/productFeatures.txt.gz");

            // The rank is only reported in the log line; it is taken from an arbitrary user row.
            int rank = userFeatures.isEmpty() ? 0 : userFeatures.values().iterator().next().length;

            logger.info("Finished loading MF features ("+userFeatures.size()+" users and "+productFeatures.size() +
                    " products at rank " + rank +") for " + client);
            return new ClientMfFeaturesStore(userFeatures, productFeatures);
        } catch (IOException e) {
            // FileNotFoundException is an IOException, so a missing file ends up here too.
            logger.error("Couldn't reloadFeatures for client "+ client, e);
            return null;
        }
    }

    /**
     * Opens one feature file, parses it, and always closes the reader.
     *
     * @param resource full location of the feature file
     * @return the parsed id-to-features map
     * @throws IOException if the resource cannot be opened, read or parsed
     */
    private Map<Long, float[]> loadFeaturesFile(String resource) throws IOException {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(featuresFileHandler.getResourceStream(resource)))) {
            return readFeatures(reader);
        }
    }

    /**
     * Parses feature lines of the form {@code <id>|<f1>,<f2>,...}.
     *
     * <p>Blank lines are skipped. When an id occurs more than once the last
     * occurrence wins. The reader is not closed here; that is the caller's job.
     *
     * @param reader source of the feature lines
     * @return map from id to its feature vector
     * @throws IOException if reading fails, or if a line has no {@code |}
     *                     separator or contains a value that is not a number
     */
    private Map<Long,float[]> readFeatures(BufferedReader reader) throws IOException {
        Map<Long, float[]> featuresById = new HashMap<>();
        String line;
        int lineNumber = 0;
        while ((line = reader.readLine()) != null) {
            lineNumber++;
            if (line.trim().isEmpty()) {
                continue;
            }
            String[] idAndValues = line.split("\\|");
            if (idAndValues.length < 2) {
                throw new IOException("Malformed features line " + lineNumber
                        + ": expected <id>|<f1>,<f2>,...");
            }
            try {
                Long id = Long.parseLong(idAndValues[0]);
                String[] rawValues = idAndValues[1].split(",");
                float[] values = new float[rawValues.length];
                for (int i = 0; i < values.length; i++) {
                    values[i] = Float.parseFloat(rawValues[i]);
                }
                featuresById.put(id, values);
            } catch (NumberFormatException e) {
                // Report bad data through the same channel as any other unreadable file.
                throw new IOException("Malformed features line " + lineNumber + ": " + e.getMessage(), e);
            }
        }
        return featuresById;
    }

    /**
     * Feature vectors for one client plus the matrix used for user fold-in.
     */
    public static class ClientMfFeaturesStore {

        /** User id to user feature vector, exactly as passed to the constructor. */
        public final Map<Long, float[]> userFeatures;

        /** Product id to product feature vector, exactly as passed to the constructor. */
        public final Map<Long, float[]> productFeatures;

        /**
         * {@code Y * (Y^T Y)^-1} where {@code Y} is the product feature matrix,
         * one row per product. {@code null} when there are no product features
         * or when the matrix could not be inverted.
         */
        public final double[][] productFeaturesInverse;

        /** Product id to its row index in the product feature matrix; empty when no matrix was built. */
        public final Map<Long,Integer> idMap;

        /**
         * Stores the two feature maps and derives the fold-in matrix from the
         * product features.
         *
         * @param userFeatures    user id to feature vector
         * @param productFeatures product id to feature vector; may be empty, in
         *                        which case no fold-in matrix is computed
         */
        public ClientMfFeaturesStore(Map<Long, float[]> userFeatures,
                                     Map<Long, float[]> productFeatures){
            this.userFeatures = userFeatures;
            this.productFeatures = productFeatures;
            this.idMap = new HashMap<>();

            int numProducts = productFeatures.size();
            // The rank is taken from an arbitrary product; guard the empty map
            // so that iterator().next() cannot throw.
            int numLatentFactors = numProducts == 0
                    ? 0
                    : productFeatures.values().iterator().next().length;

            if (numProducts == 0 || numLatentFactors == 0) {
                logger.warn("No product features available, skipping creation of fold in matrix");
                this.productFeaturesInverse = null;
            } else {
                double[][] itemFactors = new double[numProducts][numLatentFactors];
                int row = 0;
                for (Map.Entry<Long, float[]> entry : productFeatures.entrySet()) {
                    idMap.put(entry.getKey(), row);
                    float[] source = entry.getValue();
                    // NOTE(review): assumes every product has at least numLatentFactors values.
                    for (int col = 0; col < numLatentFactors; col++) {
                        itemFactors[row][col] = source[col];
                    }
                    row++;
                }
                this.productFeaturesInverse = computeUserFoldInMatrix(itemFactors);
                if (this.productFeaturesInverse != null) {
                    logger.info("Successfully created inverse of product feature matrix for fold in");
                }
            }
        }

        /**
         * Computes {@code Y * (Y^T Y)^-1} for the item factor matrix {@code Y}.
         * See http://www.slideshare.net/fullscreen/srowen/matrix-factorization/16
         *
         * @param itemFactors item factor matrix, one row per item and one
         *                    column per latent factor
         * @return a matrix with the same dimensions as {@code itemFactors}, or
         *         {@code null} if an {@link InvalidMatrixException} is raised
         *         while inverting {@code Y^T Y} (the failure is logged)
         */
        private static double[][] computeUserFoldInMatrix(double[][] itemFactors)
        {
            try
            {
                RealMatrix y = new Array2DRowRealMatrix(itemFactors);
                RealMatrix yTy = y.transpose().multiply(y);
                RealMatrix yTyInverse = new LUDecompositionImpl(yTy).getSolver().getInverse();
                return y.multiply(yTyInverse).getData();
            }
            catch (InvalidMatrixException e)
            {
                logger.warn("Failed to create inverse of products feature matrix",e);
                return null;
            }
        }
    }
}