/* * Seldon -- open source prediction engine * ======================================= * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/) * ********************************************************************************************** * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ********************************************************************************************** */ package io.seldon.api.state; import io.seldon.prediction.FeatureTransformer; import io.seldon.prediction.FeatureTransformerStrategy; import io.seldon.prediction.PredictionAlgorithm; import io.seldon.prediction.PredictionAlgorithmStrategy; import io.seldon.prediction.PredictionStrategy; import io.seldon.prediction.SimplePredictionStrategy; import io.seldon.prediction.VariationPredictionStrategy; import java.io.IOException; import java.math.BigDecimal; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.PostConstruct; import org.apache.commons.lang.StringUtils; import org.apache.log4j.Logger; import org.springframework.beans.BeansException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.stereotype.Component; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; @Component public class PredictionAlgorithmStore implements ApplicationContextAware,ClientConfigUpdateListener,GlobalConfigUpdateListener { protected static Logger logger = Logger.getLogger(PredictionAlgorithmStore.class.getName()); public static final String ALG_KEY = "predict_algs"; private ConcurrentMap<String, PredictionStrategy> predictionStore = new ConcurrentHashMap<>(); private ConcurrentMap<String, FeatureTransformerStrategy> transformerStore = new ConcurrentHashMap<>(); private PredictionStrategy defaultStrategy; private final ClientConfigHandler configHandler; private final GlobalConfigHandler globalConfigHandler; private ApplicationContext applicationContext; @Autowired public PredictionAlgorithmStore(ClientConfigHandler configHandler, GlobalConfigHandler globalConfigHandler) { this.configHandler = configHandler; this.globalConfigHandler = globalConfigHandler; } @PostConstruct private void init(){ logger.info("Initializing..."); configHandler.addListener(this); globalConfigHandler.addSubscriber("default_prediction_strategy", this); } public PredictionStrategy retrieveStrategy(String client) { PredictionStrategy strategy = predictionStore.get(client); if(strategy!=null){ return strategy; } else { return defaultStrategy; } } @Override public void configUpdated(String configKey, String configValue) { configValue = StringUtils.strip(configValue); logger.info("KEY WAS " + configKey); logger.info("Received new default strategy: " + configValue); if (StringUtils.length(configValue) == 0) { logger.warn("*WARNING* no default strategy is set!"); } else { try { ObjectMapper mapper = new ObjectMapper(); List<PredictionAlgorithmStrategy> strategies = new ArrayList<>(); AlgorithmConfig config = mapper.readValue(configValue, AlgorithmConfig.class); for (Algorithm algorithm : config.algorithms) { PredictionAlgorithmStrategy strategy = toAlgorithmStrategy(algorithm); strategies.add(strategy); } List<FeatureTransformerStrategy> featureTransformerStrategies = new ArrayList<>(); for (Transformer transformer : config.transformers) { FeatureTransformerStrategy strategy = toFeatureTransformerStrategy(transformer); featureTransformerStrategies.add(strategy); } defaultStrategy = new SimplePredictionStrategy(PredictionStrategy.DEFAULT_NAME,Collections.unmodifiableList(featureTransformerStrategies),Collections.unmodifiableList(strategies)); logger.info("Successfully added new default prediction strategy"); } catch (IOException | BeansException e) { logger.error("Couldn't update default prediction strategy", e); } } } @Override public void configUpdated(String client, String configKey,String configValue) { SimpleModule module = new SimpleModule("PredictionStrategyDeserializerModule"); module.addDeserializer(Strategy.class, new PredictionStrategyDeserializer()); ObjectMapper mapper = new ObjectMapper(); mapper.registerModule(module); if (configKey.equals(ALG_KEY)){ logger.info("Received new algorithm config for "+ client+": "+ configValue); try { Strategy configStrategy = mapper.readValue(configValue, Strategy.class); if (configStrategy instanceof AlgorithmConfig) { List<PredictionAlgorithmStrategy> strategies = new ArrayList<>(); AlgorithmConfig config = (AlgorithmConfig) configStrategy; for (Algorithm algorithm : config.algorithms) { PredictionAlgorithmStrategy strategy = toAlgorithmStrategy(algorithm); strategies.add(strategy); } List<FeatureTransformerStrategy> featureTransformerStrategies = new ArrayList<>(); if (config.transformers != null) for (Transformer transformer : config.transformers) { FeatureTransformerStrategy strategy = toFeatureTransformerStrategy(transformer); featureTransformerStrategies.add(strategy); } predictionStore.put(client, new SimplePredictionStrategy(PredictionStrategy.DEFAULT_NAME,Collections.unmodifiableList(featureTransformerStrategies),Collections.unmodifiableList(strategies))); logger.info("Successfully added new algorithm config for "+client); } else if (configStrategy instanceof TestConfig) { TestConfig config = (TestConfig) configStrategy; //TestConfig config = mapper.readValue(configValue, TestConfig.class); List<VariationPredictionStrategy.Variation> variations = new ArrayList<>(); for (TestVariation var : config.variations){ List<PredictionAlgorithmStrategy> strategies = new ArrayList<>(); for (Algorithm alg : var.config.algorithms){ PredictionAlgorithmStrategy strategy = toAlgorithmStrategy(alg); strategies.add(strategy); } List<FeatureTransformerStrategy> featureTransformerStrategies = new ArrayList<>(); if (var.config.transformers != null) for (Transformer transformer : var.config.transformers) { FeatureTransformerStrategy strategy = toFeatureTransformerStrategy(transformer); featureTransformerStrategies.add(strategy); } //Need to add combiner here variations.add(new VariationPredictionStrategy.Variation( new SimplePredictionStrategy(var.label,Collections.unmodifiableList(featureTransformerStrategies),Collections.unmodifiableList(strategies)), new BigDecimal(var.ratio))); } predictionStore.put(client,VariationPredictionStrategy.build(variations)); logger.info("Succesfully added " + variations.size() + " variation test for "+ client); } else { logger.error("Unknown type for algorithm config"); } } catch (IOException | BeansException e) { logger.error("Couldn't update algorithms for client " +client, e); } } } @Override public void configRemoved(String client, String configKey) { if (configKey.equals(ALG_KEY)){ predictionStore.remove(client); logger.info("Removed client "+client+" from "+ALG_KEY); } } @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { this.applicationContext = applicationContext; StringBuilder builder= new StringBuilder("Available algorithms: \n"); for (PredictionAlgorithm inc : applicationContext.getBeansOfType(PredictionAlgorithm.class).values()){ builder.append('\t'); builder.append(inc.getClass()); builder.append('\n'); } builder.append("Available feature transformers: \n"); for(FeatureTransformer f : applicationContext.getBeansOfType(FeatureTransformer.class).values()){ builder.append('\t'); builder.append(f.getClass()); builder.append('\n'); } logger.info(builder.toString()); } private PredictionAlgorithmStrategy toAlgorithmStrategy(Algorithm algorithm) { PredictionAlgorithm alg = applicationContext.getBean(algorithm.name, PredictionAlgorithm.class); Map<String, String> config = toConfigMap(algorithm.config); return new PredictionAlgorithmStrategy(alg,algorithm.config ==null ? new HashMap<String, String>(): config , algorithm.name); } private FeatureTransformerStrategy toFeatureTransformerStrategy(Transformer transformer) { FeatureTransformer t = applicationContext.getBean(transformer.name,FeatureTransformer.class); Map<String, String> config = toConfigMap(transformer.config); return new FeatureTransformerStrategy(t, transformer.inputCols, transformer.outputCols, config); } private Map<String, String> toConfigMap(List<ConfigItem> config) { Map<String, String> configMap = new HashMap<>(); if (config==null) return configMap; for (ConfigItem item : config){ configMap.put(item.name,item.value); } return configMap; } public abstract static class Strategy {} // classes for json translation public static class TestConfig extends Strategy { public List<TestVariation> variations; } public static class TestVariation{ public String label; public String ratio; public AlgorithmConfig config; } public static class AlgorithmConfig extends Strategy { public List<Algorithm> algorithms; public List<Transformer> transformers; public String combiner; } public static class ConfigItem { public String name; public String value; } public static class Algorithm { public String name; public List<ConfigItem> config; } public static class Transformer { public String name; public List<String> inputCols; public List<String> outputCols; public List<ConfigItem> config; } }