package io.seldon.rpc; import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; import java.util.concurrent.ConcurrentHashMap; import org.apache.log4j.Logger; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import com.google.protobuf.util.JsonFormat; import com.google.protobuf.util.JsonFormat.TypeRegistry; import io.seldon.api.resource.service.business.PredictionBusinessServiceImpl; import io.seldon.api.rpc.ClassificationReply; import io.seldon.api.rpc.ClassificationRequest; import io.seldon.api.rpc.DefaultCustomPredictRequest; import io.seldon.api.state.ClientConfigHandler; import io.seldon.api.state.ClientConfigUpdateListener; @Component public class ClientRpcStore implements ClientConfigUpdateListener { private static Logger logger = Logger.getLogger(ClientRpcStore.class.getName()); public static final String RPC_KEY = "rpc"; ConcurrentHashMap<String,RPCConfig> services = new ConcurrentHashMap<String, ClientRpcStore.RPCConfig>(); @Autowired public ClientRpcStore(ClientConfigHandler configHandler) { logger.info("starting up"); configHandler.addListener(this); } public RPCConfig getRPCConfig(String client) { if (services.containsKey(client)) { return services.get(client); } else return null; } private JsonNode getJSONFromMethod(Method m,Message msg,String fieldname) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException, JsonParseException, IOException { Message.Builder o2 = (Message.Builder) m.invoke(null); TypeRegistry registry = TypeRegistry.newBuilder().add(o2.getDescriptorForType()).build(); JsonFormat.Printer jPrinter = JsonFormat.printer(); String result = jPrinter.usingTypeRegistry(registry).print(msg); ObjectMapper mapper = new ObjectMapper(); JsonFactory factory = mapper.getFactory(); JsonParser parser = factory.createParser(result); JsonNode jNode = mapper.readTree(parser); if (jNode.has(fieldname) && jNode.get(fieldname).has("@type")) ((ObjectNode) jNode.get(fieldname)).remove("@type"); return jNode; } private JsonNode getJSON(Message msg,String fieldname) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException, JsonParseException, IOException { JsonFormat.Printer jPrinter = JsonFormat.printer(); String result = jPrinter.print(msg); ObjectMapper mapper = new ObjectMapper(); JsonFactory factory = mapper.getFactory(); JsonParser parser = factory.createParser(result); JsonNode jNode = mapper.readTree(parser); return jNode; } private JsonNode getDefaultRequestJSON(Message msg) throws JsonParseException, IOException { Message.Builder o2 = DefaultCustomPredictRequest.newBuilder(); TypeRegistry registry = TypeRegistry.newBuilder().add(o2.getDescriptorForType()).build(); JsonFormat.Printer jPrinter = JsonFormat.printer(); String result = jPrinter.usingTypeRegistry(registry).print(msg); ObjectMapper mapper = new ObjectMapper(); JsonFactory factory = mapper.getFactory(); JsonParser parser = factory.createParser(result); JsonNode jNode = mapper.readTree(parser); if (jNode.has(PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD)) { JsonNode values = jNode.get(PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD).get("values"); ((ObjectNode) jNode).set(PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD, values); } return jNode; } public JsonNode getJSONForRequest(String client,ClassificationRequest request) { RPCConfig config = services.get(client); if (config != null) { try { if (config.requestClass != null) { Method m = config.requestBuilder; return getJSONFromMethod(m, request, PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD); } else return getDefaultRequestJSON(request); } catch (Exception e) { logger.error("Failed to create JSON request for client "+client,e); return null; } } else { try { return getDefaultRequestJSON(request); } catch (Exception e) { logger.error("Failed to create JSON request for client from default "+client,e); return null; } } } public JsonNode getJSONForReply(String client,ClassificationReply request) { RPCConfig config = services.get(client); try { if (config != null && config.replyClass != null) { Method m = config.replyBuilder; return getJSONFromMethod(m, request, PredictionBusinessServiceImpl.REPLY_CUSTOM_DATA_FIELD); } else return getJSON(request, PredictionBusinessServiceImpl.REPLY_CUSTOM_DATA_FIELD); } catch (Exception e) { logger.error("Failed to create JSON reply for client "+client,e); return null; } } public ClassificationReply getPredictReplyFromJson(String client,JsonNode json) { RPCConfig config = services.get(client); try { TypeRegistry registry = null; if (config != null && config.replyClass != null && json.has(PredictionBusinessServiceImpl.REPLY_CUSTOM_DATA_FIELD)) { if (!json.get(PredictionBusinessServiceImpl.REPLY_CUSTOM_DATA_FIELD).has("@type")) ((ObjectNode) json.get(PredictionBusinessServiceImpl.REPLY_CUSTOM_DATA_FIELD)).put("@type", "type.googleapis.com/" + config.replyClass.getName()); Method m = config.replyBuilder; Message.Builder o = (Message.Builder) m.invoke(null); registry = TypeRegistry.newBuilder().add(o.getDescriptorForType()).build(); } ClassificationReply.Builder builder = ClassificationReply.newBuilder(); JsonFormat.Parser jFormatter = JsonFormat.parser(); if (registry != null) jFormatter = jFormatter.usingTypeRegistry(registry); jFormatter.merge(json.toString(), builder); ClassificationReply reply = builder.build(); return reply; } catch (Exception e) { logger.error("Failed to convert json "+json.toString()+" to PredictReply",e); return null; } } private ClassificationRequest getPredictRequestWithCustomDefaultFromJSON(JsonNode json) throws InvalidProtocolBufferException { ObjectMapper mapper = new ObjectMapper(); ObjectNode data = mapper.createObjectNode(); data.put("@type", "type.googleapis.com/" + DefaultCustomPredictRequest.class.getName()); data.put("values", json.get(PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD)); ((ObjectNode) json).put(PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD, data); Message.Builder o = DefaultCustomPredictRequest.newBuilder(); TypeRegistry registry = TypeRegistry.newBuilder().add(o.getDescriptorForType()).build(); ClassificationRequest.Builder builder = ClassificationRequest.newBuilder(); JsonFormat.Parser jFormatter = JsonFormat.parser(); if (registry != null) jFormatter = jFormatter.usingTypeRegistry(registry); jFormatter.merge(json.toString(), builder); ClassificationRequest request = builder.build(); return request; } public ClassificationRequest getPredictRequestFromJson(String client,JsonNode json) { RPCConfig config = services.get(client); if (config != null) { try { TypeRegistry registry = null; if (config.requestClass != null && json.has(PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD)) { if (!json.get(PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD).has("@type")) ((ObjectNode) json.get(PredictionBusinessServiceImpl.REQUEST_CUSTOM_DATA_FIELD)).put("@type", "type.googleapis.com/" + config.requestClass.getName()); Method m = config.requestBuilder; Message.Builder o = (Message.Builder) m.invoke(null); registry = TypeRegistry.newBuilder().add(o.getDescriptorForType()).build(); } ClassificationRequest.Builder builder = ClassificationRequest.newBuilder(); JsonFormat.Parser jFormatter = JsonFormat.parser(); if (registry != null) jFormatter = jFormatter.usingTypeRegistry(registry); jFormatter.merge(json.toString(), builder); ClassificationRequest request = builder.build(); return request; } catch (Exception e) { logger.error("Failed to convert json "+json.toString()+" to PredictRequest",e); return null; } } else { try { return this.getPredictRequestWithCustomDefaultFromJSON(json); } catch (Exception e) { logger.error("Failed to convert json "+json.toString()+" to PredictRequest using Default",e); return null; } } } void add(String client,Class<?> requestClass,Class<?> responseClass,Method requestBuilder,Method replyBuilder) { RPCConfig config = new RPCConfig(); config.requestClass = requestClass; config.replyClass = responseClass; config.requestBuilder = requestBuilder; config.replyBuilder = replyBuilder; services.put(client, config); } private void createClientConfig(String client,String data) { try { ObjectMapper mapper = new ObjectMapper(); RPCZkConfig config = mapper.readValue(data, RPCZkConfig.class); File f = new File(config.jarFilename); try { URL myURL = f.toURI().toURL(); URL[] urls = {myURL}; URLClassLoader cLoader = new URLClassLoader (urls, this.getClass().getClassLoader()); Class<?> requestClass = null; Class<?> responseClass = null; Method requestBuilder = null; Method replyBuilder = null; if (org.apache.commons.lang.StringUtils.isNotEmpty(config.requestClassName)) { requestClass = Class.forName(config.requestClassName,true,cLoader); requestBuilder = requestClass.getMethod("newBuilder"); } if (org.apache.commons.lang.StringUtils.isNotEmpty(config.replyClassName)) { responseClass = Class.forName(config.replyClassName,true,cLoader); replyBuilder = requestClass.getMethod("newBuilder"); } this.add(client, requestClass, responseClass,requestBuilder,replyBuilder); } catch (MalformedURLException e) { logger.error("Bad url "+config.jarFilename,e); } catch (ClassNotFoundException e) { logger.error("Failed to load class ",e); } catch (NoSuchMethodException e) { logger.error("Failed to load class ",e); } catch (SecurityException e) { logger.error("Failed to load class ",e); } finally{} } catch (JsonParseException e1) { logger.error("Failed to parse json "+data,e1); } catch (JsonMappingException e1) { logger.error("Failed to parse json "+data,e1); } catch (IOException e1) { logger.error("Failed to parse json "+data,e1); } finally{} } @Override public void configUpdated(String client, String configKey, String configValue) { if (configKey.equals(RPC_KEY)) { logger.info("New client location "+client+" at "+configKey+" value "+configValue); if (org.apache.commons.lang.StringUtils.isNotEmpty(configValue)) { this.createClientConfig(client, configValue); } else { logger.warn("Ignoring as no data provided "+configValue+" for client "+client); } } } @Override public void configRemoved(String client, String configKey) { if (configKey.equals(RPC_KEY)) { logger.info("Removing client "+client); services.remove(client); } } public static class RPCConfigMap { public ConcurrentHashMap<String, RPCConfig> nameMap; public RPCConfigMap() { this.nameMap = new ConcurrentHashMap<>(); } } public static class RPCConfig { public Class<?> requestClass; public Method requestBuilder; public Class<?> replyClass; public Method replyBuilder; } public static class RPCZkConfig { public String jarFilename; public String requestClassName; public String replyClassName; } }