/*
* Seldon -- open source prediction engine
* =======================================
*
* Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
*
* ********************************************************************************************
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* ********************************************************************************************
*/
package io.seldon.recommendation.model;
import io.seldon.clustering.recommender.RecommendationContext;
import io.seldon.mf.PerClientExternalLocationListener;
import io.seldon.recommendation.ClientStrategy;
import io.seldon.resources.external.NewResourceNotifier;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import org.apache.log4j.Logger;
/**
 * Keeps per-client models that are (re)loaded whenever a {@link NewResourceNotifier} reports a new
 * external location for a client.
 *
 * <p>For every node pattern passed to the constructor this manager registers itself as a listener.
 * When a location arrives for a node, the registered base contained in the node name is found, the
 * model is loaded asynchronously via {@link #loadModel(String, String)} on the configured executor,
 * and stored in the per-client store keyed {@code client + ":" + base}. Inside that store the model
 * is keyed by the node name with the base removed and its first {@code /} dropped (the empty string
 * for the base node itself); that name is what the {@code io.seldon.algorithm.model.name} option
 * selects in {@code getClientStore}.
 *
 * @param <T> type of the model held per client
 * @author firemanphil
 *         Date: 28/04/15
 *         Time: 12:16
 */
public abstract class ModelManager<T> implements PerClientExternalLocationListener {

    /** Recommendation option naming which model in a client's store should be served. */
    private static final String MODEL_PROPERTY_NAME = "io.seldon.algorithm.model.name";

    /** Key of the model loaded from the base node itself (no extra node suffix). */
    private static final String DEFAULT_MODEL_KEY = "";

    private static final Logger logger = Logger.getLogger(ModelManager.class.getName());

    /** client:base -> (model name -> model). Outer entries are never removed, only inner ones. */
    private final ConcurrentMap<String, ConcurrentMap<String, T>> clientStores = new ConcurrentHashMap<>();
    private final Executor executor;
    /** Registered node bases, copied so later changes to the caller's set cannot affect us. */
    private final Set<String> nodeBases;

    /**
     * Creates a manager that loads models on its own two-thread pool.
     * NOTE(review): this class never shuts that pool down.
     *
     * @param notifier     source of new/deleted location events
     * @param nodePatterns node bases to listen on
     */
    public ModelManager(NewResourceNotifier notifier, Set<String> nodePatterns) {
        this(notifier, nodePatterns, Executors.newFixedThreadPool(2));
    }

    /**
     * Creates a manager that loads models on the given executor.
     *
     * @param notifier     source of new/deleted location events; this instance is registered as a
     *                     listener for every pattern
     * @param nodePatterns node bases to listen on; the first one (in the set's iteration order) is
     *                     the default type used by {@code getClientStore(client, options)}
     * @param executor     executor on which {@link #loadModel(String, String)} runs
     */
    public ModelManager(NewResourceNotifier notifier, Set<String> nodePatterns, Executor executor) {
        // Assign all state before registering: a notifier may call back into this listener
        // synchronously from addListener, and newClientLocation needs both fields.
        // NOTE(review): 'this' still escapes before a subclass constructor has run, so a callback
        // could reach loadModel before subclass fields are initialised.
        this.executor = executor;
        this.nodeBases = new LinkedHashSet<>(nodePatterns);
        for (String pattern : nodePatterns) {
            notifier.addListener(pattern, this);
        }
    }

    /**
     * Schedules an asynchronous load of the model at {@code location} and stores it for the client.
     * Nodes that match no registered base are logged and ignored. A load that throws or returns
     * {@code null} is logged and leaves any previously stored model untouched.
     */
    @Override
    public void newClientLocation(final String client, final String location, final String nodePattern) {
        logger.info("New location "+client+" : "+location+ " : "+nodePattern);
        final String base = findBase(nodePattern);
        if (base == null) {
            logger.warn("Ignoring new location " + location + " for client " + client
                    + ": node " + nodePattern + " matches none of " + nodeBases);
            return;
        }
        // Same key scheme as clientLocationDeleted and getClientStore: client + base.
        final String key = getKey(client, base);
        final String finalPartOfNode = modelNameFor(nodePattern, base);
        executor.execute(new Runnable() {
            @Override
            public void run() {
                logger.info("Loading with client:"+client+" location:"+location+" key:"+key+" finalPartOfNode:"+finalPartOfNode);
                final T result;
                try {
                    result = loadModel(location, client);
                } catch (RuntimeException e) {
                    // Executor task boundary: report via log4j instead of the pool's default handler.
                    logger.error("Failed to load model with client:" + client + " location:" + location
                            + " key:" + key, e);
                    return;
                }
                logger.info("Loaded with client:"+client+" location:"+location+" key:"+key+" finalPartOfNode:"+finalPartOfNode+" result:"+result);
                if (result == null) {
                    // ConcurrentHashMap rejects null values; keep whatever model was there before.
                    logger.warn("Model load returned null for client:" + client + " location:" + location
                            + " key:" + key + "; keeping previous model");
                    return;
                }
                storeFor(key).put(finalPartOfNode, result);
                if (logger.isDebugEnabled()) {
                    logAllStores();
                }
            }
        });
    }

    /**
     * Returns the client's model for the first registered node base.
     *
     * @return the selected model, or {@code null} if no base is registered or nothing is stored
     */
    public T getClientStore(String client, RecommendationContext.OptionsHolder options){
        Iterator<String> bases = nodeBases.iterator();
        if (!bases.hasNext()) {
            return null;
        }
        return getClientStore(client, bases.next(), options);
    }

    /**
     * Returns the client's model for the given node base, selected by the
     * {@code io.seldon.algorithm.model.name} option.
     *
     * <p>A missing option or {@link ClientStrategy#DEFAULT_NAME} selects the default model. Any
     * other name selects that model, falling back to the default model (with a warning) when the
     * name is unknown.
     *
     * @return the selected model, or {@code null} if the client has no store for {@code type} or
     *         the store lacks the requested and default models
     */
    public T getClientStore(String client, String type, RecommendationContext.OptionsHolder options){
        String modelName = options.getStringOption(MODEL_PROPERTY_NAME);
        String key = getKey(client, type);
        if (logger.isDebugEnabled())
            logger.debug("Get client store for client "+client+" type "+type+" modelName "+modelName+" key:"+key);
        Map<String, T> store = clientStores.get(key);
        if (store == null) {
            if (logger.isDebugEnabled())
                logger.debug("Failed to find store with key:"+key+" for client "+client);
            return null;
        }
        // Check whether we are testing or not and get the relevant model.
        if (modelName == null || ClientStrategy.DEFAULT_NAME.equals(modelName)) {
            if (logger.isDebugEnabled())
                logger.debug("Returning default store for client "+client);
            return store.get(DEFAULT_MODEL_KEY);
        }
        T model = store.get(modelName);
        if (model == null) {
            logger.warn("Couldn't find model under name " + modelName + " for client " + client);
            return store.get(DEFAULT_MODEL_KEY);
        }
        return model;
    }

    /**
     * Removes the model that was stored for {@code nodePattern}. Nodes matching no registered base
     * are logged and ignored.
     * NOTE(review): loads run asynchronously, so a load queued before this deletion may finish
     * afterwards and re-add the model.
     */
    @Override
    public void clientLocationDeleted(String client, String nodePattern) {
        String base = findBase(nodePattern);
        if (base == null) {
            logger.warn("Ignoring deleted location for client " + client + ": node " + nodePattern
                    + " matches none of " + nodeBases);
            return;
        }
        Map<String, T> store = clientStores.get(getKey(client, base));
        if (store != null) {
            store.remove(modelNameFor(nodePattern, base));
        }
    }

    /**
     * Loads the model stored at {@code location} for {@code client}. Called on the executor thread.
     *
     * @return the loaded model; {@code null} is treated as a failed load
     */
    protected abstract T loadModel(String location,String client);

    /** Returns the longest registered base contained in {@code nodePattern}, or null if none is. */
    private String findBase(String nodePattern) {
        String best = null;
        for (String base : nodeBases) {
            if (nodePattern.contains(base) && (best == null || base.length() > best.length())) {
                best = base;
            }
        }
        return best;
    }

    /** Derives the per-store model name: node name without the base and its first '/'. */
    private static String modelNameFor(String nodePattern, String base) {
        return nodePattern.replace(base, "").replaceFirst("/", "");
    }

    /** Returns the store for {@code key}, creating it atomically on first use. */
    private ConcurrentMap<String, T> storeFor(String key) {
        ConcurrentMap<String, T> store = clientStores.get(key);
        if (store == null) {
            ConcurrentMap<String, T> created = new ConcurrentHashMap<>();
            store = clientStores.putIfAbsent(key, created);
            if (store == null) {
                store = created;
            }
        }
        return store;
    }

    /** Debug aid: logs every stored model name and model across all clients. */
    private void logAllStores() {
        for (Map<String, T> store : clientStores.values()) {
            for (Map.Entry<String, T> entry : store.entrySet()) {
                logger.debug(entry.getKey() + " " + entry.getValue());
            }
        }
    }

    private String getKey(String client,String key)
    {
        return client + ":" + key;
    }
}