/* * Seldon -- open source prediction engine * ======================================= * * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/) * * ******************************************************************************************** * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * ******************************************************************************************** */ package io.seldon.api.state; import io.seldon.clustering.recommender.ItemRecommendationAlgorithm; import io.seldon.recommendation.AlgorithmStrategy; import io.seldon.recommendation.ClientStrategy; import io.seldon.recommendation.ItemFilter; import io.seldon.recommendation.ItemIncluder; import io.seldon.recommendation.JsOverrideClientStrategy; import io.seldon.recommendation.RecTagClientStrategy; import io.seldon.recommendation.SimpleClientStrategy; import io.seldon.recommendation.VariationTestingClientStrategy; import io.seldon.recommendation.combiner.AlgorithmResultsCombiner; import io.seldon.recommendation.filters.base.CurrentItemFilter; import io.seldon.recommendation.filters.base.IgnoredRecsFilter; import io.seldon.recommendation.filters.base.RecentImpressionsFilter; import java.io.IOException; import java.math.BigDecimal; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.PostConstruct; import org.apache.commons.lang.StringUtils; import org.apache.commons.lang3.BooleanUtils; import org.apache.log4j.Logger; import org.springframework.beans.BeansException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.stereotype.Component; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; import com.google.common.collect.Sets; /** * Cache of which algorithms to use for which clients. Receives updates via ClientConfigUpdateListener * * @author firemanphil * Date: 27/11/14 * Time: 13:55 */ @Component public class ClientAlgorithmStore implements ApplicationContextAware,ClientConfigUpdateListener,GlobalConfigUpdateListener { private static final String RECTAG = "alg_rectags"; protected static Logger logger = Logger.getLogger(ClientAlgorithmStore.class.getName()); private static final String ALG_KEY = "algs"; private static final String TESTING_SWITCH_KEY = "alg_test_switch"; private static final String TEST = "alg_test"; private final ClientConfigHandler configHandler; private final GlobalConfigHandler globalConfigHandler; private final Set<ItemFilter> alwaysOnFilters; private ApplicationContext applicationContext; private ConcurrentMap<String, ClientStrategy> store = new ConcurrentHashMap<>(); private ConcurrentMap<String, Map<String, AlgorithmStrategy>> storeMap = new ConcurrentHashMap<>(); private ConcurrentMap<String, Boolean> testingOnOff = new ConcurrentHashMap<>(); private ConcurrentMap<String, ClientStrategy> tests = new ConcurrentHashMap<>(); private ConcurrentMap<String, ClientStrategy> recTagStrategies = new ConcurrentHashMap<>(); private ClientStrategy defaultStrategy = null; private ConcurrentMap<String, ClientStrategy> namedStrategies = new ConcurrentHashMap<>(); @Autowired public ClientAlgorithmStore (ClientConfigHandler configHandler, GlobalConfigHandler globalConfigHandler, CurrentItemFilter currentItemFilter, IgnoredRecsFilter ignoredRecsFilter, RecentImpressionsFilter recentImpressionsFilter){ this.configHandler = configHandler; this.globalConfigHandler = globalConfigHandler; Set<ItemFilter> set = new HashSet<>(); set.add (currentItemFilter); set.add(ignoredRecsFilter); set.add(recentImpressionsFilter); alwaysOnFilters = Collections.unmodifiableSet(set); } @PostConstruct private void init(){ logger.info("Initializing..."); configHandler.addListener(this); globalConfigHandler.addSubscriber("default_strategy", this); // globalConfigHandler.addSubscriber("named_strategies", this); } public ClientStrategy retrieveStrategy(String client, Collection<String> algs){ ClientStrategy originalStrat = retrieveStrategy(client); return new JsOverrideClientStrategy(originalStrat, algs, applicationContext); } public ClientStrategy retrieveStrategy(String client){ if (testRunning(client)){ ClientStrategy strategy = tests.get(client); if(strategy!=null){ return strategy; } else { logger.warn("Testing was switch on for client " + client+ " but no test was specified." + " Returning default strategy"); return defaultStrategy; } } else { ClientStrategy strategy; if(recTagStrategies.containsKey(client)) strategy= recTagStrategies.get(client); else strategy = store.get(client); if(strategy!=null){ return strategy; } else { return defaultStrategy; } } } @Override public void configRemoved(String client, String configKey) { logger.info("Received config remove for "+client+" with key "+configKey); if (configKey.equals(ALG_KEY)){ store.remove(client); logger.info("Successfully removed "+client+" from "+ALG_KEY); } else if (configKey.equals(TESTING_SWITCH_KEY)){ testingOnOff.remove(client); logger.info("Successfully removed "+client+" from "+TESTING_SWITCH_KEY); } else if (configKey.equals(TEST)) { tests.remove(client); logger.info("Successfully removed "+client+" from "+TEST); } else if(configKey.equals(RECTAG)){ recTagStrategies.remove(client); logger.info("Successfully removed "+client+" from "+RECTAG); } else logger.warn("Ignored unknow config remove for "+client+" with key "+configKey); } @Override public void configUpdated(String client, String configKey, String configValue) { SimpleModule module = new SimpleModule("StrategyDeserializerModule"); module.addDeserializer(Strategy.class, new StrategyDeserializer()); ObjectMapper mapper = new ObjectMapper(); mapper.registerModule(module); if (configKey.equals(ALG_KEY)){ logger.info("Received new algorithm config for "+ client+": "+ configValue); try { List<AlgorithmStrategy> strategies = new ArrayList<>(); Map<String, AlgorithmStrategy> stratMap = new HashMap<>(); AlgorithmConfig config = mapper.readValue(configValue, AlgorithmConfig.class); for (Algorithm algorithm : config.algorithms) { AlgorithmStrategy strategy = toAlgorithmStrategy(algorithm); strategies.add(strategy); } AlgorithmResultsCombiner combiner = applicationContext.getBean( config.combiner,AlgorithmResultsCombiner.class); Map<Integer,Double> actionWeightMap = toActionWeightMap(config.actionWeights); store.put(client, new SimpleClientStrategy(Collections.unmodifiableList(strategies), combiner,config.diversityLevel, ClientStrategy.DEFAULT_NAME,actionWeightMap)); storeMap.put(client, Collections.unmodifiableMap(stratMap)); logger.info("Successfully added new algorithm config for "+client); } catch (IOException | BeansException e) { logger.error("Couldn't update algorithms for client " +client, e); } } else if (configKey.equals(TESTING_SWITCH_KEY)){ // not json as its so simple Boolean onOff = BooleanUtils.toBooleanObject(configValue); if(onOff==null){ logger.error("Couldn't set testing switch for client "+client +", input was " +configValue); } else { logger.info("Testing switch for client " + client + " moving from '" + BooleanUtils.toStringOnOff(testingOnOff.get(client)) + "' to '" + BooleanUtils.toStringOnOff(onOff)+"'"); testingOnOff.put(client, BooleanUtils.toBooleanObject(configValue)); } } else if (configKey.equals(TEST)) { logger.info("Received new testing config for " + client + ":" + configValue); try { TestConfig config = mapper.readValue(configValue, TestConfig.class); List<VariationTestingClientStrategy.Variation> variations = new ArrayList<>(); for (TestVariation var : config.variations){ List<AlgorithmStrategy> strategies = new ArrayList<>(); for (Algorithm alg : var.config.algorithms){ AlgorithmStrategy strategy = toAlgorithmStrategy(alg); strategies.add(strategy); } AlgorithmResultsCombiner combiner = applicationContext.getBean( var.config.combiner,AlgorithmResultsCombiner.class); Map<Integer,Double> actionWeightMap = toActionWeightMap(var.config.actionWeights); variations.add(new VariationTestingClientStrategy.Variation( new SimpleClientStrategy(Collections.unmodifiableList(strategies), combiner, var.config.diversityLevel, var.label,actionWeightMap), new BigDecimal(var.ratio))); } tests.put(client,VariationTestingClientStrategy.build(variations)); logger.info("Succesfully added " + variations.size() + " variation test for "+ client); } catch (NumberFormatException | IOException e) { logger.error("Couldn't add test for client " +client, e); } } else if(configKey.equals(RECTAG)){ logger.info("Received new rectag config for "+ client + ": " +configValue); try{ RecTagConfig config = mapper.readValue(configValue, RecTagConfig.class); if(config.defaultStrategy==null){ logger.error("Couldn't read rectag config as there was no default alg"); return; } ClientStrategy defStrategy = toStrategy(config.defaultStrategy); Map<String, ClientStrategy> recTagStrats = new HashMap<>(); for (Map.Entry<String, Strategy> entry : config.recTagToStrategy.entrySet() ){ recTagStrats.put(entry.getKey(),toStrategy(entry.getValue())); } recTagStrategies.put(client, new RecTagClientStrategy(defStrategy, recTagStrats)); logger.info("Successfully added rec tag strategy for " + client); } catch (NumberFormatException | IOException e) { logger.error("Couldn't add rectag strategy for client " +client, e); } } } private ClientStrategy toStrategy(Strategy jsonStrategy) { if (jsonStrategy instanceof TestConfig){ TestConfig jsonStrategyTest = (TestConfig) jsonStrategy; List<VariationTestingClientStrategy.Variation> variations = new ArrayList<>(); for (TestVariation var : jsonStrategyTest.variations){ List<AlgorithmStrategy> strategies = new ArrayList<>(); for (Algorithm alg : var.config.algorithms){ AlgorithmStrategy strategy = toAlgorithmStrategy(alg); strategies.add(strategy); } AlgorithmResultsCombiner combiner = applicationContext.getBean( var.config.combiner,AlgorithmResultsCombiner.class); Map<Integer,Double> actionWeightMap = toActionWeightMap(var.config.actionWeights); variations.add(new VariationTestingClientStrategy.Variation( new SimpleClientStrategy(Collections.unmodifiableList(strategies), combiner, var.config.diversityLevel, var.label,actionWeightMap), new BigDecimal(var.ratio))); } return VariationTestingClientStrategy.build(variations); } else { AlgorithmConfig jsonStrategyAlg = (AlgorithmConfig) jsonStrategy; List<AlgorithmStrategy> defaultAlgStrategies = new ArrayList<>(); for (Algorithm alg : jsonStrategyAlg.algorithms){ AlgorithmStrategy strategy = toAlgorithmStrategy(alg); defaultAlgStrategies.add(strategy); } AlgorithmResultsCombiner defCombiner = applicationContext.getBean( jsonStrategyAlg.combiner, AlgorithmResultsCombiner.class); Map<Integer,Double> defActionWeightMap = toActionWeightMap(jsonStrategyAlg.actionWeights); return new SimpleClientStrategy(defaultAlgStrategies, defCombiner, jsonStrategyAlg.diversityLevel,"-",defActionWeightMap); } } private AlgorithmStrategy toAlgorithmStrategy(Algorithm algorithm) { Set<ItemIncluder> includers = retrieveIncluders(algorithm.includers); Set<ItemFilter> filters = retrieveFilters(algorithm.filters); ItemRecommendationAlgorithm alg = applicationContext.getBean(algorithm.name, ItemRecommendationAlgorithm.class); Map<String, String> config = toConfigMap(algorithm.config); return new AlgorithmStrategy(alg, includers, filters, algorithm.config ==null ? new HashMap<String, String>(): config , algorithm.name); } private Map<String, String> toConfigMap(List<ConfigItem> config) { Map<String, String> configMap = new HashMap<>(); if (config==null) return configMap; for (ConfigItem item : config){ configMap.put(item.name,item.value); } return configMap; } private Map<Integer,Double> toActionWeightMap(List<ActionWeightItem> weights) { Map<Integer,Double> weightMap = new HashMap<Integer,Double>(); if (weights == null) return Collections.unmodifiableMap(weightMap); for(ActionWeightItem item : weights){ weightMap.put(Integer.parseInt(item.type), item.value); } return Collections.unmodifiableMap(weightMap); } private Set<ItemIncluder> retrieveIncluders(List<String> includers) { Set<ItemIncluder> includerSet = new HashSet<>(); if(includers==null) return includerSet; for (String includer : includers){ includerSet.add(applicationContext.getBean(includer, ItemIncluder.class)); } return includerSet; } private Set<ItemFilter> retrieveFilters(List<String> filters) { Set<ItemFilter> filterSet = new HashSet<>(); if(filters==null) return alwaysOnFilters; for (String filter : filters){ filterSet.add(applicationContext.getBean(filter, ItemFilter.class)); } return Sets.union(filterSet, alwaysOnFilters); } @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { this.applicationContext = applicationContext; StringBuilder builder= new StringBuilder("Available algorithms: \n"); for (ItemRecommendationAlgorithm inc : applicationContext.getBeansOfType(ItemRecommendationAlgorithm.class).values()){ builder.append('\t'); builder.append(inc.getClass()); builder.append('\n'); } logger.info(builder.toString()); builder= new StringBuilder("Available includers: \n"); for (ItemIncluder inc : applicationContext.getBeansOfType(ItemIncluder.class).values()){ builder.append('\t'); builder.append(inc.getClass()); builder.append('\n'); } logger.info(builder.toString()); builder = new StringBuilder("Available filters: \n" ); for (ItemFilter filt: applicationContext.getBeansOfType(ItemFilter.class).values()){ builder.append('\t'); builder.append(filt.getClass()); builder.append('\n'); } logger.info(builder.toString()); for (AlgorithmResultsCombiner filt: applicationContext.getBeansOfType(AlgorithmResultsCombiner.class).values()){ builder.append('\t'); builder.append(filt.getClass()); builder.append('\n'); } builder = new StringBuilder("Available combiners: \n" ); logger.info(builder.toString()); } private boolean testRunning(String client){ return testingOnOff.get(client)!=null && testingOnOff.get(client) && tests.get(client)!=null; } @Override public void configUpdated(String configKey, String configValue) { configValue = StringUtils.strip(configValue); logger.info("KEY WAS " + configKey); logger.info("Received new default strategy: " + configValue); if (StringUtils.length(configValue) == 0) { logger.warn("*WARNING* no default strategy is set!"); } else { try { ObjectMapper mapper = new ObjectMapper(); List<AlgorithmStrategy> strategies = new ArrayList<>(); AlgorithmConfig config = mapper.readValue(configValue, AlgorithmConfig.class); for (Algorithm alg : config.algorithms){ strategies.add(toAlgorithmStrategy(alg)); } AlgorithmResultsCombiner combiner = applicationContext.getBean( config.combiner,AlgorithmResultsCombiner.class); Map<Integer,Double> actionWeightMap = toActionWeightMap(config.actionWeights); ClientStrategy strat = new SimpleClientStrategy(strategies, combiner, config.diversityLevel,"-",actionWeightMap); defaultStrategy = strat; logger.info("Successfully changed default strategy."); } catch (IOException e){ logger.error("Problem changing default strategy ", e); } } } // classes for json translation public static class TestConfig extends Strategy { public List<TestVariation> variations; } public static class RecTagConfig { public Map<String, Strategy> recTagToStrategy; public Strategy defaultStrategy; } private static class TestVariation{ public String label; public String ratio; public AlgorithmConfig config; } public abstract static class Strategy {} public static class AlgorithmConfig extends Strategy { public List<Algorithm> algorithms; public String combiner; public Double diversityLevel; public List<ActionWeightItem> actionWeights; } public static class ConfigItem { public String name; public String value; } public static class ActionWeightItem { public String type; public Double value; } public static class Algorithm { public String name; public List<String> includers; public List<String> filters; public List<ConfigItem> config; } }