package io.seldon.rpc; import java.io.IOException; import javax.annotation.PostConstruct; import org.apache.log4j.Logger; import org.datanucleus.util.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.ServerCall; import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptors; import io.grpc.StatusException; import io.grpc.stub.StreamObserver; import io.seldon.api.APIException; import io.seldon.api.Constants; import io.seldon.api.logging.PredictLogger; import io.seldon.api.resource.ConsumerBean; import io.seldon.api.rpc.ClassificationReply; import io.seldon.api.rpc.ClassificationRequest; import io.seldon.api.rpc.SeldonGrpc; import io.seldon.api.service.ResourceServer; import io.seldon.prediction.PredictionService; @Component public class ExternalRpcServer extends SeldonGrpc.SeldonImplBase implements ServerInterceptor { private static Logger logger = Logger.getLogger(ExternalRpcServer.class.getName()); private static final int port = 5000; private final Server server; private final PredictionService predictionService; @Autowired private ResourceServer resourceServer; @Autowired PredictLogger predictLogger; final Metadata.Key<String> authKey = Metadata.Key.of(Constants.OAUTH_TOKEN,Metadata.ASCII_STRING_MARSHALLER); ThreadLocal<String> clientThreadLocal = new ThreadLocal<String>(); public static class SeldonServerCallListener<R> extends ForwardingServerCallListener<R> { ServerCall.Listener<R> delegate; ExternalRpcServer server; String client; public SeldonServerCallListener(ServerCall.Listener<R> delegate,String client,ExternalRpcServer server) { this.delegate = delegate; this.server = server; this.client = client; } @Override protected Listener<R> delegate() { return delegate; } @Override public void onMessage(R request) { server.clientThreadLocal.set(client); super.onMessage(request); } } @Autowired public ExternalRpcServer(PredictionService predictionService) { logger.info("Initializing RPC server..."); this.predictionService = predictionService; ServerBuilder<?> serverBuilder = ServerBuilder.forPort(port); server = serverBuilder.addService(ServerInterceptors.intercept(this, this)).build(); } @PostConstruct public void startup(){ logger.info("Starting RPC server"); try { start(); } catch (IOException e) { logger.error("Failed to start RPC server ",e); } } @Override public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,ServerCallHandler<ReqT, RespT> next) { logger.info("Call intercepted "+headers.toString()); String token = headers.get(authKey); if (StringUtils.notEmpty(token)) { try { logger.info("Token "+token); ConsumerBean consumer = resourceServer.validateResourceFromToken(token); logger.info("Setting call to client "+consumer.getShort_name()); return new SeldonServerCallListener<ReqT>(next.startCall(call, headers),consumer.getShort_name(),this); } catch (APIException e) { logger.warn("API exception on getting token ",e); return next.startCall(call, headers); } } else { logger.warn("Empty token ignoring call"); return next.startCall(call, headers); } } @Override public void classify(ClassificationRequest request, StreamObserver<ClassificationReply> responseObserver) { final String client = clientThreadLocal.get(); if (StringUtils.notEmpty(client)) { clientThreadLocal.set(null); ClassificationReply reply = predictionService.predict(client, request); responseObserver.onNext(reply); responseObserver.onCompleted(); predictLogger.log(client, request, reply); } else { logger.info("Failed to get token"); responseObserver.onError(new StatusException(io.grpc.Status.PERMISSION_DENIED.withDescription("Could not determine client from oauth_token"))); } } /** Start serving requests. */ public void start() throws IOException { server.start(); logger.info("Server started"); Runtime.getRuntime().addShutdownHook(new Thread() { @Override public void run() { logger.info("Shutting down"); } }); } /** Stop serving requests and shutdown resources. */ public void stop() { if (server != null) { server.shutdown(); } } /** * Await termination on the main thread since the grpc library uses daemon threads. */ private void blockUntilShutdown() throws InterruptedException { if (server != null) { server.awaitTermination(); } } }