package org.deephacks.westty.protobuf; import com.google.common.base.Preconditions; import com.google.protobuf.MessageLite; import org.deephacks.westty.protobuf.FailureMessages.Failure; import org.deephacks.westty.config.ProtobufConfig; import org.deephacks.westty.spi.IoExecutors; import org.jboss.netty.bootstrap.ClientBootstrap; import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelEvent; import org.jboss.netty.channel.ChannelFactory; import org.jboss.netty.channel.ChannelFuture; import org.jboss.netty.channel.ChannelHandler.Sharable; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelPipeline; import org.jboss.netty.channel.ChannelPipelineFactory; import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.Channels; import org.jboss.netty.channel.ExceptionEvent; import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.channel.SimpleChannelHandler; import org.jboss.netty.channel.group.ChannelGroup; import org.jboss.netty.channel.group.DefaultChannelGroup; import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory; import org.jboss.netty.handler.codec.frame.LengthFieldBasedFrameDecoder; import org.jboss.netty.handler.codec.frame.LengthFieldPrepender; import org.jboss.netty.handler.codec.oneone.OneToOneDecoder; import org.jboss.netty.handler.codec.oneone.OneToOneEncoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.inject.Inject; import javax.inject.Singleton; import java.io.IOException; import java.net.ConnectException; import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import static org.jboss.netty.buffer.ChannelBuffers.wrappedBuffer; @Singleton public class ProtobufClient { private static final Logger log = LoggerFactory.getLogger(ProtobufClient.class); public static final int MESSAGE_LENGTH = 4; public static final int MESSAGE_MAX_SIZE_10MB = 10485760; private final ProtobufSerializer serializer; private final Lock lock = new ReentrantLock(); private final ConcurrentHashMap<Integer, ConcurrentLinkedQueue<Callback>> callbacks = new ConcurrentHashMap<>(); private final ChannelFactory factory; private final ChannelGroup channelGroup = new DefaultChannelGroup("channels"); private final ClientProtobufDecoder decoder; private final ClientProtobufEncoder encoder; private final LengthFieldPrepender lengthPrepender = new LengthFieldPrepender(MESSAGE_LENGTH); private final ClientHandler clientHandler = new ClientHandler(); private final ProtobufConfig config; @Inject public ProtobufClient(IoExecutors excutors, ProtobufSerializer serializer, ProtobufConfig config) { this.factory = new NioClientSocketChannelFactory(excutors.getBoss(), excutors.getWorker()); this.serializer = serializer; this.decoder = new ClientProtobufDecoder(serializer); this.encoder = new ClientProtobufEncoder(); this.config = config; } public int connect() throws IOException { return connect(new InetSocketAddress(config.getPort())); } public int connect(InetSocketAddress address) throws IOException { ClientBootstrap bootstrap = new ClientBootstrap(factory); bootstrap.setPipelineFactory(new ChannelPipelineFactory() { @Override public ChannelPipeline getPipeline() throws Exception { ChannelPipeline pipeline = Channels.pipeline(); pipeline.addLast("lengthFrameDecoder", new LengthFieldBasedFrameDecoder( MESSAGE_MAX_SIZE_10MB, 0, MESSAGE_LENGTH, 0, MESSAGE_LENGTH)); pipeline.addLast("decoder", decoder); pipeline.addLast("lengthPrepender", lengthPrepender); pipeline.addLast("encoder", encoder); pipeline.addLast("handler", clientHandler); return pipeline; } }); ChannelFuture future = bootstrap.connect(address); if (!future.awaitUninterruptibly().isSuccess()) { bootstrap.releaseExternalResources(); throw new IllegalArgumentException("Could not connect to " + address); } Channel channel = future.getChannel(); if (!channel.isConnected()) { bootstrap.releaseExternalResources(); throw new IllegalStateException("Channel could not connect to " + address); } channelGroup.add(channel); return channel.getId(); } public void registerResource(String protodesc) { serializer.registerResource(protodesc); } public ChannelFuture callAsync(Integer channelId, Object protoMsg) throws IOException { Channel channel = channelGroup.find(channelId); byte[] bytes = serializer.write(protoMsg); if (channel == null || !channel.isOpen()) { throw new IOException("Channel is not open"); } return channel.write(ChannelBuffers.wrappedBuffer(bytes)); } public Object callSync(int id, Object protoMsg) throws IOException, FailureMessageException { Preconditions.checkNotNull(protoMsg); Channel channel = channelGroup.find(id); if(channel == null){ throw new IllegalArgumentException("Channel not found ["+id+"]"); } byte[] bytes = serializer.write(protoMsg); Callback callback = new Callback(); lock.lock(); try { if (!channel.isOpen()) { throw new IOException("Channel is not open"); } ConcurrentLinkedQueue<Callback> callbackQueue = callbacks.get(id); if(callbackQueue == null){ callbackQueue = new ConcurrentLinkedQueue<>(); callbacks.put(id, callbackQueue); } callbackQueue.add(callback); channel.write(ChannelBuffers.wrappedBuffer(bytes)); } finally { lock.unlock(); } Object res = callback.get(); if (res instanceof Failure) { Failure failure = (Failure) res; throw new FailureMessageException(failure); } if(res instanceof VoidMessage.Void){ return null; } return res; } public void disconnect(Integer id) { Channel channel = channelGroup.find(id); if (channel != null && channel.isConnected()) { channel.close().awaitUninterruptibly(); } } public void shutdown() { channelGroup.close().awaitUninterruptibly(); final class ShutdownNetty extends Thread { public void run() { factory.releaseExternalResources(); } } new ShutdownNetty().start(); } class ClientHandler extends SimpleChannelHandler { @Override public void handleUpstream(final ChannelHandlerContext ctx, final ChannelEvent e) throws Exception { if (e instanceof ChannelStateEvent) { log.debug(e.toString()); } super.handleUpstream(ctx, e); } @Override public void channelClosed(ChannelHandlerContext ctx, final ChannelStateEvent e) throws Exception { disconnect(e.getChannel().getId()); } @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) { ConcurrentLinkedQueue<Callback> callbackQueue = callbacks.get(e.getChannel().getId()); if(callbackQueue == null){ return; } Callback callback = callbackQueue.poll(); if (callback != null) { callback.handle(e.getMessage()); } } @Override public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception { final Throwable cause = e.getCause(); final Channel ch = ctx.getChannel(); if (cause instanceof ClosedChannelException) { log.warn("Attempt to write to closed channel." + ch); disconnect(e.getChannel().getId()); } else if (cause instanceof IOException && "Connection reset by peer".equals(cause.getMessage())) { disconnect(e.getChannel().getId()); } else if (cause instanceof ConnectException && "Connection refused".equals(cause.getMessage())) { // server not up, nothing to do } else { log.error("Unexpected exception.", e.getCause()); disconnect(e.getChannel().getId()); } } } @Sharable static class ClientProtobufDecoder extends OneToOneDecoder { private final ProtobufSerializer serializer; public ClientProtobufDecoder(ProtobufSerializer serializer) { this.serializer = serializer; } @Override protected Object decode(ChannelHandlerContext ctx, Channel channel, Object msg) throws Exception { if (!(msg instanceof ChannelBuffer)) { return msg; } ChannelBuffer buf = (ChannelBuffer) msg; return serializer.read(buf.array()); } } @Sharable static class ClientProtobufEncoder extends OneToOneEncoder { @Override protected Object encode(ChannelHandlerContext ctx, Channel channel, Object msg) throws Exception { if (msg instanceof MessageLite) { return wrappedBuffer(((MessageLite) msg).toByteArray()); } if (msg instanceof MessageLite.Builder) { return wrappedBuffer(((MessageLite.Builder) msg).build().toByteArray()); } return msg; } } static class Callback { private final CountDownLatch latch = new CountDownLatch(1); private Object response; Object get() { try { latch.await(); } catch (InterruptedException e) { throw new RuntimeException(e); } return response; } void handle(Object response) { this.response = response; latch.countDown(); } } }