package io.seldon.rpc; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import javax.annotation.PostConstruct; import org.apache.log4j.Logger; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.stereotype.Component; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.seldon.api.state.ClientConfigHandler; import io.seldon.api.state.ClientConfigUpdateListener; import io.seldon.api.state.GlobalConfigHandler; import io.seldon.api.state.PredictionAlgorithmStore; /** * RPC Channel Handler. Basic implementation that create a new Channel for each client until new zookeeper configuration appears which * might suggest a new service has been created and thus we should refresh the channel. This is required as DNS changes will be the channel might * keep failing and not get refreshed in present gRPC setup. See https://github.com/grpc/grpc-java/issues/1463 * @author clive * */ @Component public class SimpleRpcChannelHandlerImpl implements RpcChannelHandler,ClientConfigUpdateListener{ private static Logger logger = Logger.getLogger(SimpleRpcChannelHandlerImpl.class.getName()); ConcurrentHashMap<String,ConcurrentHashMap<String,ManagedChannel>> channels = new ConcurrentHashMap<String, ConcurrentHashMap<String,ManagedChannel>>(); private final ClientConfigHandler configHandler; @Autowired public SimpleRpcChannelHandlerImpl(ClientConfigHandler configHandler) { this.configHandler = configHandler; } @PostConstruct private void init(){ logger.info("Initializing..."); configHandler.addListener(this); } private String getKey(String host,int port) { return host+":"+port; } private ManagedChannel addChannel(String client,String host,int port) { ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port) // Channels are secure by default (via SSL/TLS). For the example we disable TLS to avoid // needing certificates. .usePlaintext(true) .build(); ConcurrentHashMap<String,ManagedChannel> clientChannels = channels.putIfAbsent(client, new ConcurrentHashMap<String,ManagedChannel>()); clientChannels.put(getKey(host,port), channel); channels.put(client, clientChannels); return channel; } private void clearChannels(String client) { Map<String,ManagedChannel> currentChannels = channels.put(client, new ConcurrentHashMap<String,ManagedChannel>()); if (currentChannels != null) for (ManagedChannel ch : currentChannels.values()) { ch.shutdown(); } } @Override public ManagedChannel getChannel(String client,String host,int port) { String key = host+":"+port; if (channels.containsKey(client) && channels.get(client).containsKey(key)) { return channels.get(client).get(key); } else { return addChannel(client, host, port); } } @Override public void configUpdated(String client, String configKey, String configValue) { if (configKey.equals(PredictionAlgorithmStore.ALG_KEY)) { logger.info("Clearing existing channels for client "+client); clearChannels(client); } } @Override public void configRemoved(String client, String configKey) { if (configKey.equals(PredictionAlgorithmStore.ALG_KEY)) { logger.info("Clearing existing channels for client "+client); clearChannels(client); } } }