/*
 * Seldon -- open source prediction engine
 * =======================================
 * Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
 *
 **********************************************************************************************
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 **********************************************************************************************
 */
package io.seldon.prediction;

import java.util.ArrayList;

import org.apache.commons.lang3.StringUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import com.fasterxml.jackson.databind.JsonNode;

import io.seldon.api.APIException;
import io.seldon.api.logging.PredictLogger;
import io.seldon.api.rpc.ClassificationReply;
import io.seldon.api.rpc.ClassificationReplyMeta;
import io.seldon.api.rpc.ClassificationRequest;
import io.seldon.api.state.PredictionAlgorithmStore;
import io.seldon.api.state.options.DefaultOptions;
import io.seldon.clustering.recommender.RecommendationContext.OptionsHolder;
import io.seldon.memcache.SecurityHashPeer;

/**
 * Runs a client's configured prediction algorithms in order and returns the
 * first non-empty set of predictions, decorated with metadata (prediction
 * unique id, strategy variation label and model name).
 *
 * <p>Two entry points are offered: one for RPC-style
 * {@link ClassificationRequest} messages and one for raw JSON input.
 */
@Component
public class PredictionService {

    private static final Logger logger = Logger.getLogger(PredictionService.class.getName());

    private final DefaultOptions defaultOptions;
    private final PredictionAlgorithmStore algStore;

    /**
     * @param algStore       store from which the per-client prediction strategy is retrieved
     * @param defaultOptions default options, combined with each algorithm's own
     *                       config into the {@link OptionsHolder} passed to it
     */
    @Autowired
    public PredictionService(PredictionAlgorithmStore algStore, DefaultOptions defaultOptions) {
        this.algStore = algStore;
        this.defaultOptions = defaultOptions;
    }

    /**
     * Predicts from an RPC classification request.
     *
     * <p>Each algorithm of the client's strategy is tried in order; the first
     * reply containing at least one prediction is returned with freshly built
     * metadata. If no algorithm yields predictions, a warning is logged and an
     * empty reply is returned.
     *
     * @param client  the client whose strategy should be used
     * @param request the classification request passed to each algorithm
     * @return the decorated reply of the first successful algorithm, or an
     *         empty {@link ClassificationReply} if none produced predictions
     * @throws APIException with {@code NOT_VALID_STRATEGY} if the client has no strategy
     */
    public ClassificationReply predict(String client, ClassificationRequest request) {
        SimplePredictionStrategy strategy = configuredStrategyFor(client);

        // apply prediction algorithm(s)
        for (PredictionAlgorithmStrategy algStr : strategy.getAlgorithms()) {
            OptionsHolder optsHolder = new OptionsHolder(defaultOptions, algStr.config);
            ClassificationReply res = algStr.algorithm.predictFromProto(client, request, optsHolder);
            if (res != null && res.getPredictionsCount() > 0) {
                return decorateReply(request, res, strategy.label, algStr.name);
            }
        }

        logger.warn("No prediction for client "+client);
        return ClassificationReply.newBuilder().build();
    }

    /**
     * Predicts from a JSON payload.
     *
     * <p>Each algorithm of the client's strategy is tried in order; the first
     * result containing at least one prediction is returned after its metadata
     * has been filled in (created if absent, otherwise its puid and variation
     * are overwritten). If no algorithm yields predictions, a warning is logged
     * and a result with an empty prediction list and an empty model name is
     * returned.
     *
     * @param client the client whose strategy should be used
     * @param puid   prediction unique id to attach; a new id is generated when {@code null}
     * @param json   the input passed to each algorithm
     * @return the first non-empty algorithm result, or an empty result if none
     * @throws APIException with {@code NOT_VALID_STRATEGY} if the client has no strategy
     */
    public PredictionServiceResult predict(String client, String puid, JsonNode json) {
        SimplePredictionStrategy strategy = configuredStrategyFor(client);

        // Feature transformation is currently disabled:
        //for(FeatureTransformerStrategy transStr : strategy.getFeatureTansformers())
        //{
        //    json = transStr.transformer.transform(client, json, transStr);
        //}

        if (puid == null) {
            puid = SecurityHashPeer.getNewId();
        }

        // apply prediction algorithm(s)
        for (PredictionAlgorithmStrategy algStr : strategy.getAlgorithms()) {
            OptionsHolder optsHolder = new OptionsHolder(defaultOptions, algStr.config);
            PredictionServiceResult result = algStr.algorithm.predictFromJSON(client, json, optsHolder);
            if (result == null || result.predictions == null || result.predictions.isEmpty()) {
                continue; // nothing usable from this algorithm, try the next one
            }
            if (result.meta == null) {
                result.meta = new PredictionMetadata(algStr.name, strategy.label, puid);
            } else {
                result.meta.setPuid(puid);
                result.meta.variation = strategy.label;
            }
            return result;
        }

        logger.warn("No prediction for client "+client+" with json "+json);
        PredictionMetadata meta = new PredictionMetadata("", strategy.label, puid);
        return new PredictionServiceResult(meta, new ArrayList<PredictionResult>(), null);
    }

    /**
     * Looks up the client's strategy and resolves it to a concrete one.
     *
     * @throws APIException with {@code NOT_VALID_STRATEGY} if the store has no strategy for the client
     */
    private SimplePredictionStrategy configuredStrategyFor(String client) {
        PredictionStrategy strategyTop = algStore.retrieveStrategy(client);
        if (strategyTop == null) {
            throw new APIException(APIException.NOT_VALID_STRATEGY);
        }
        return strategyTop.configure();
    }

    /**
     * Builds the reply returned to the caller from an algorithm's raw reply.
     *
     * <p>NOTE(review): {@link ClassificationReply} appears to be a
     * protobuf-generated message (builder API, {@code hasCustom()}); such
     * getters return a default instance rather than {@code null}, so presence
     * of the meta and custom fields is tested with {@code hasMeta()} /
     * {@code hasCustom()} instead of a null comparison.
     *
     * @param request           the original request; its puid is reused when present and non-empty
     * @param res               the algorithm's reply, containing at least one prediction
     * @param variation         the strategy label to record as the variation
     * @param fallbackModelName model name to use when the algorithm's reply carries no meta
     */
    private static ClassificationReply decorateReply(ClassificationRequest request,
                                                     ClassificationReply res,
                                                     String variation,
                                                     String fallbackModelName) {
        // Honour a caller-supplied puid regardless of whether the algorithm set meta,
        // mirroring the JSON entry point; otherwise mint a new one.
        String puid;
        if (request.hasMeta() && !StringUtils.isEmpty(request.getMeta().getPuid())) {
            puid = request.getMeta().getPuid();
        } else {
            puid = SecurityHashPeer.getNewId();
        }

        String modelName = res.hasMeta() ? res.getMeta().getModelName() : fallbackModelName;

        ClassificationReplyMeta meta = ClassificationReplyMeta.newBuilder()
                .setPuid(puid)
                .setVariation(variation)
                .setModelName(modelName)
                .build();

        ClassificationReply.Builder reply = ClassificationReply.newBuilder()
                .setMeta(meta)
                .addAllPredictions(res.getPredictionsList());
        // Only copy the custom payload when the algorithm actually set one, so an
        // unset field is not turned into a present-but-empty one.
        if (res.hasCustom()) {
            reply.setCustom(res.getCustom());
        }
        return reply.build();
    }
}