package net.notdot.protorpc; import java.net.SocketAddress; import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelPipelineCoverage; import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.channel.SimpleChannelHandler; import org.jboss.netty.channel.group.ChannelGroup; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.RpcCallback; import com.google.protobuf.Service; import com.google.protobuf.Message; import com.google.protobuf.Descriptors.MethodDescriptor; @ChannelPipelineCoverage("one") public class ProtoRpcHandler extends SimpleChannelHandler { protected class ProtoRpcCallback implements RpcCallback<Message> { protected ProtoRpcController controller; protected ProtoRpcCallback(ProtoRpcController controller) { this.controller = controller; } public void run(Message arg0) { Rpc.Response response = Rpc.Response.newBuilder() .setRpcId(this.controller.getRpcId()) .setStatus(Rpc.Response.ResponseType.OK) .setBody(arg0.toByteString()).build(); this.controller.sendResponse(response); } } final Logger logger = LoggerFactory.getLogger(ProtoRpcHandler.class); protected Service service; protected ChannelGroup open_channels; @Override public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { open_channels.add(e.getChannel()); logger.info("New connection from {}.", ctx.getChannel().getRemoteAddress()); } @Override public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { if(service instanceof Disposable) ((Disposable)service).close(); logger.info("Client {} disconnected.", ctx.getChannel().getRemoteAddress()); } public ProtoRpcHandler(Service s, ChannelGroup channels) { this.service = s; this.open_channels = channels; } @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { Channel ch = e.getChannel(); SocketAddress remote_addr = ch.getRemoteAddress(); Rpc.Request request = (Rpc.Request) e.getMessage(); // For now we ignore the service name in the message. MethodDescriptor method = service.getDescriptorForType().findMethodByName(request.getMethod()); if(method == null) { ch.write(Rpc.Response.newBuilder().setStatus(Rpc.Response.ResponseType.CALL_NOT_FOUND).build()); return; } Message request_data; try { request_data = service.getRequestPrototype(method).newBuilderForType().mergeFrom(request.getBody()).build(); } catch(InvalidProtocolBufferException ex) { ch.write(Rpc.Response.newBuilder().setStatus(Rpc.Response.ResponseType.ARGUMENT_ERROR).build()); return; } logger.debug("Client {} RPC {} request data: {}", new Object[] { remote_addr, request.getRpcId(), request_data }); ProtoRpcController controller = new ProtoRpcController(ch, request.getService(), request.getMethod(), request.getRpcId()); ProtoRpcCallback callback = new ProtoRpcCallback(controller); try { service.callMethod(method, controller, request_data, callback); } catch(RpcFailedError ex) { if(ex.getCause() != null) { logger.error("Internal error: ", ex.getCause()); } controller.setFailed(ex.getMessage(), ex.getApplicationError()); } catch(Throwable t) { logger.error("Internal error: ", t); controller.setFailed(t.getMessage()); } if(!controller.isResponseSent()) { controller.sendResponse(Rpc.Response.newBuilder() .setRpcId(request.getRpcId()) .setStatus(Rpc.Response.ResponseType.RPC_FAILED) .setErrorDetail("RPC handler failed to issue a response").build()); logger.error("Client {} RPC {} failed to return a response.", new Object[] { remote_addr, request.getRpcId() }); } } }