/*
* Seldon -- open source prediction engine
* =======================================
*
* Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
*
* ********************************************************************************************
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* ********************************************************************************************
*/
package io.seldon.cc;
import io.seldon.api.resource.ConsumerBean;
import io.seldon.api.resource.service.ItemService;
import io.seldon.clustering.recommender.MemoryUserClusterStore;
import io.seldon.clustering.recommender.UserCluster;
import io.seldon.db.jdo.JDOFactory;
import io.seldon.mf.PerClientExternalLocationListener;
import io.seldon.recommendation.model.ModelManager;
import io.seldon.resources.external.ExternalResourceStreamer;
import io.seldon.resources.external.NewResourceNotifier;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.databind.ObjectMapper;
@Component
public class UserClusterManager implements PerClientExternalLocationListener {

    private static final Logger logger = Logger.getLogger(UserClusterManager.class.getName());

    // Loaded per-client user-cluster stores, keyed by client (consumer) short name.
    private final ConcurrentMap<String, MemoryUserClusterStore> clientStores = new ConcurrentHashMap<>();
    // Human-readable cluster descriptions per client, rebuilt on every model load.
    private final ConcurrentMap<String, ClusterDescription> clusterDescriptions = new ConcurrentHashMap<>();
    // Held only to keep a strong reference to the notifier we registered with.
    private final Set<NewResourceNotifier> notifiers = new HashSet<>();
    private final ExternalResourceStreamer featuresFileHandler;
    public static final String CLUSTER_NEW_LOC_PATTERN = "userclusters";
    private static UserClusterManager theManager; // hack until rest of code Springified
    // ObjectMapper is expensive to create and thread-safe once configured; share one.
    private static final ObjectMapper mapper = new ObjectMapper();

    private final BlockingQueue<Runnable> queue = new LinkedBlockingDeque<>();
    // Background pool for model reloads. afterExecute releases any JDO
    // persistence manager a reload task may have opened on its worker thread.
    private final ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 5, 10, TimeUnit.MINUTES, queue) {
        @Override
        protected void afterExecute(Runnable runnable, Throwable throwable) {
            super.afterExecute(runnable, throwable);
            JDOFactory.get().cleanupPM();
        }
    };
    private final ItemService itemService;

    /**
     * Creates the manager and registers it as a listener for new
     * {@link #CLUSTER_NEW_LOC_PATTERN} model locations.
     *
     * @param featuresFileHandler streamer used to fetch cluster files from external storage
     * @param notifier            notifier announcing new model locations per client
     * @param itemService         resolves dimension ids to human-readable names
     */
    @Autowired
    public UserClusterManager(ExternalResourceStreamer featuresFileHandler, NewResourceNotifier notifier, ItemService itemService) {
        this.featuresFileHandler = featuresFileHandler;
        this.itemService = itemService;
        notifiers.add(notifier);
        notifier.addListener(CLUSTER_NEW_LOC_PATTERN, this);
        theManager = this; // static publication from the constructor; see field note
    }

    /** Returns the singleton instance published by the Spring-created constructor. */
    public static UserClusterManager get() {
        return theManager;
    }

    /**
     * Asynchronously (re)loads the user-cluster model for a client. The model is
     * expected at {@code location + "/part-00000"} as one JSON
     * {@code UserDimWeight} record per line; on success the new store replaces any
     * previous one for the client.
     *
     * @param location base path/URL of the new model
     * @param client   client (consumer) short name
     */
    public void reloadFeatures(final String location, final String client) {
        executor.execute(new Runnable() {
            @Override
            public void run() {
                logger.info("Reloading user clusters for client: " + client);
                // try-with-resources guarantees the stream is closed even when
                // parsing fails (previously the reader leaked on error). UTF-8 is
                // assumed for the model file rather than the platform default.
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                        featuresFileHandler.getResourceStream(location + "/part-00000"),
                        StandardCharsets.UTF_8))) {
                    MemoryUserClusterStore userClusters = loadUserClusters(client, reader);
                    clientStores.put(client, userClusters);
                    logger.info("finished load of user clusters for client " + client);
                } catch (IOException e) { // covers FileNotFoundException too
                    logger.error("Couldn't reloadFeatures for client " + client, e);
                }
            }
        });
    }

    /**
     * Parses one-JSON-record-per-line cluster data into an in-memory store.
     * <p>
     * NOTE(review): the user count is derived from adjacent-line transitions, so
     * the input must be grouped by user id (records for one user adjacent) —
     * presumably guaranteed by the upstream job; confirm against the producer.
     *
     * @param client client short name the store is built for
     * @param reader line-oriented source of {@code UserDimWeight} JSON records
     * @return a fully-populated store, already marked as loaded
     * @throws IOException if reading or JSON parsing fails
     */
    protected MemoryUserClusterStore loadUserClusters(String client, BufferedReader reader) throws IOException {
        String line;
        List<UserCluster> clusters = new ArrayList<>();
        int numUsers = 0;
        int numClusters = 0;
        long lastUser = -1;
        Set<Integer> dimensions = new HashSet<>();
        while ((line = reader.readLine()) != null) {
            // Parse the String directly — avoids a platform-charset getBytes() round trip.
            UserDimWeight data = mapper.readValue(line, UserDimWeight.class);
            if (lastUser != data.user)
                numUsers++;
            clusters.add(new UserCluster(data.user, data.dim, data.weight, 0, 0));
            dimensions.add(data.dim);
            lastUser = data.user;
            numClusters++;
        }
        MemoryUserClusterStore store = new MemoryUserClusterStore(client, numUsers);
        storeClusters(store, clusters);
        store.setLoaded(true);
        setClusterDescription(client, dimensions);
        logger.info("Loaded user clusters for client " + client + " with num users " + numUsers + " and number of clusters " + numClusters);
        return store;
    }

    /**
     * Builds and publishes a "dim1:dim2"-style name for each cluster dimension,
     * looked up via the item service. Lookup failures are logged and leave a
     * partial (possibly empty) description in place.
     */
    private void setClusterDescription(String client, Set<Integer> dimensions) {
        ConsumerBean c = new ConsumerBean(client);
        Map<Integer, String> clusterNames = new HashMap<>();
        try {
            for (Integer dim : dimensions) {
                String[] names = itemService.getDimensionName(c, dim);
                if (names != null && names.length == 2) {
                    clusterNames.put(dim, names[0] + ":" + names[1]);
                } else
                    logger.warn("Can't find cluster name in db for dimension " + dim + " for " + client);
            }
        } catch (Exception e) {
            logger.error("Failed to create cluster descriptions for " + client, e);
        }
        clusterDescriptions.put(client, new ClusterDescription(clusterNames));
    }

    /**
     * Groups consecutive clusters by user and stores each user's group.
     * NOTE(review): like {@link #loadUserClusters}, this relies on the list being
     * ordered so that all clusters for one user are adjacent.
     */
    private void storeClusters(MemoryUserClusterStore store, List<UserCluster> clusters) {
        long currentUser = -1;
        List<UserCluster> userClusters = new ArrayList<>();
        for (UserCluster cluster : clusters) {
            if (currentUser != -1 && currentUser != cluster.getUser()) {
                store.store(currentUser, userClusters);
                userClusters = new ArrayList<>();
            }
            userClusters.add(cluster);
            currentUser = cluster.getUser();
        }
        if (!userClusters.isEmpty())
            store.store(currentUser, userClusters);
    }

    /** Returns the loaded cluster store for a client, or {@code null} if none loaded. */
    public MemoryUserClusterStore getStore(String client) {
        return clientStores.get(client);
    }

    /** Returns the cluster descriptions for a client, or {@code null} if none loaded. */
    public ClusterDescription getClusterDescriptions(String client) {
        return clusterDescriptions.get(client);
    }

    /** Callback: a new model location was published for a client — trigger a reload. */
    @Override
    public void newClientLocation(String client, String location, String nodePattern) {
        reloadFeatures(location, client);
    }

    /** Callback: the client's model location was removed — drop its store. */
    @Override
    public void clientLocationDeleted(String client, String nodePattern) {
        clientStores.remove(client);
    }

    /** Immutable holder mapping a cluster dimension id to its display name. */
    public static class ClusterDescription {
        public final Map<Integer, String> clusterNames;

        public ClusterDescription(Map<Integer, String> clusterNames) {
            this.clusterNames = clusterNames;
        }
    }
}