package io.seldon.rpc; import java.io.IOException; import java.util.concurrent.TimeUnit; import org.apache.log4j.Logger; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectReader; import io.grpc.ManagedChannel; import io.seldon.api.rpc.ClassificationReply; import io.seldon.api.rpc.ClassificationRequest; import io.seldon.api.rpc.SeldonGrpc; import io.seldon.api.rpc.SeldonGrpc.SeldonBlockingStub; import io.seldon.clustering.recommender.RecommendationContext.OptionsHolder; import io.seldon.prediction.PredictionAlgorithm; import io.seldon.prediction.PredictionServiceResult; @Component public class RpcPredictionServer implements PredictionAlgorithm { private static Logger logger = Logger.getLogger(RpcPredictionServer.class.getName()); private static final String name = RpcPredictionServer.class.getName(); private static final String HOST_PROPERTY_NAME="io.seldon.rpc.microservice.host"; private static final String PORT_PROPERTY_NAME="io.seldon.rpc.microservice.port"; final ClientRpcStore rpcStore; RpcChannelHandler channelHandler; @Autowired public RpcPredictionServer(ClientRpcStore rpcStore,RpcChannelHandler channelHandler){ this.rpcStore = rpcStore; this.channelHandler = channelHandler; } public String getName() { return name; } @Override public PredictionServiceResult predictFromJSON(String client, JsonNode json, OptionsHolder options) { try { ClassificationRequest request = rpcStore.getPredictRequestFromJson(client, json); ClassificationReply reply = predictFromProto(client, request, options); JsonNode actualObj = rpcStore.getJSONForReply(client, reply); PredictionServiceResult res = null; ObjectMapper mapper = new ObjectMapper(); ObjectReader reader = mapper.reader(PredictionServiceResult.class); res = reader.readValue(actualObj); return res; } catch (JsonProcessingException e) { logger.error("Couldn't retrieve prediction from external prediction server - ", e); return null; } catch (IOException e) { logger.error("Couldn't retrieve prediction from external prediction server - ", e); return null; } catch (Exception e) { logger.error("Couldn't retrieve prediction from external prediction server - ", e); return null; } finally{} } @Override public ClassificationReply predictFromProto(String client, ClassificationRequest request, OptionsHolder options) { ManagedChannel channel = channelHandler.getChannel(client,options.getStringOption(HOST_PROPERTY_NAME), options.getIntegerOption(PORT_PROPERTY_NAME)); SeldonBlockingStub stub = SeldonGrpc.newBlockingStub(channel).withDeadlineAfter(5, TimeUnit.SECONDS); long t1 = System.currentTimeMillis(); logger.info("call start"); ClassificationReply reply = stub.classify(request); long t2 = System.currentTimeMillis(); long duration = t2-t1; logger.info("call end "+duration); return reply; } }