/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tajo.rpc; import com.google.protobuf.BlockingRpcChannel; import com.google.protobuf.Descriptors.MethodDescriptor; import com.google.protobuf.Message; import com.google.protobuf.RpcController; import com.google.protobuf.ServiceException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.tajo.rpc.RpcProtos.RpcRequest; import org.apache.tajo.rpc.RpcProtos.RpcResponse; import org.apache.tajo.util.NetUtils; import org.jboss.netty.channel.*; import org.jboss.netty.channel.socket.ClientSocketChannelFactory; import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.util.Map; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import static org.apache.tajo.rpc.RpcConnectionPool.RpcConnectionKey; public class BlockingRpcClient extends NettyClientBase { private static final Log LOG = LogFactory.getLog(RpcProtos.class); private final ChannelUpstreamHandler handler; private final ChannelPipelineFactory pipeFactory; private final ProxyRpcChannel rpcChannel; private final AtomicInteger sequence = new AtomicInteger(0); private final Map<Integer, ProtoCallFuture> requests = new ConcurrentHashMap<Integer, ProtoCallFuture>(); private final Class<?> protocol; private final Method stubMethod; private RpcConnectionKey key; /** * Intentionally make this method package-private, avoiding user directly * new an instance through this constructor. */ BlockingRpcClient(final Class<?> protocol, final InetSocketAddress addr) throws Exception { this(protocol, addr, RpcChannelFactory.getSharedClientChannelFactory()); } BlockingRpcClient(final Class<?> protocol, final InetSocketAddress addr, ClientSocketChannelFactory factory) throws Exception { this.protocol = protocol; String serviceClassName = protocol.getName() + "$" + protocol.getSimpleName() + "Service"; Class<?> serviceClass = Class.forName(serviceClassName); stubMethod = serviceClass.getMethod("newBlockingStub", BlockingRpcChannel.class); this.handler = new ClientChannelUpstreamHandler(); pipeFactory = new ProtoPipelineFactory(handler, RpcResponse.getDefaultInstance()); super.init(addr, pipeFactory, factory); rpcChannel = new ProxyRpcChannel(); this.key = new RpcConnectionKey(addr, protocol, false); } @Override public RpcConnectionKey getKey() { return key; } @Override public <T> T getStub() { try { return (T) stubMethod.invoke(null, rpcChannel); } catch (Exception e) { throw new RuntimeException(e.getMessage(), e); } } public BlockingRpcChannel getBlockingRpcChannel() { return this.rpcChannel; } private class ProxyRpcChannel implements BlockingRpcChannel { private final ClientChannelUpstreamHandler handler; public ProxyRpcChannel() { this.handler = getChannel().getPipeline(). get(ClientChannelUpstreamHandler.class); if (handler == null) { throw new IllegalArgumentException("Channel does not have " + "proper handler"); } } public Message callBlockingMethod(final MethodDescriptor method, final RpcController controller, final Message param, final Message responsePrototype) throws ServiceException { int nextSeqId = sequence.getAndIncrement(); Message rpcRequest = buildRequest(nextSeqId, method, param); ProtoCallFuture callFuture = new ProtoCallFuture(controller, responsePrototype); requests.put(nextSeqId, callFuture); getChannel().write(rpcRequest); try { return callFuture.get(); } catch (Throwable t) { if(t instanceof ExecutionException) { ExecutionException ee = (ExecutionException)t; throw new ServiceException(ee.getCause()); } else { throw new RemoteException(t); } } } private Message buildRequest(int seqId, MethodDescriptor method, Message param) { RpcRequest.Builder requestBuilder = RpcRequest.newBuilder() .setId(seqId) .setMethodName(method.getName()); if (param != null) { requestBuilder.setRequestMessage(param.toByteString()); } return requestBuilder.build(); } } private String getErrorMessage(String message) { if(protocol != null && getChannel() != null) { return "Exception [" + protocol.getCanonicalName() + "(" + NetUtils.normalizeInetSocketAddress((InetSocketAddress) getChannel().getRemoteAddress()) + ")]: " + message; } else { return "Exception " + message; } } private class ClientChannelUpstreamHandler extends SimpleChannelUpstreamHandler { @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { RpcResponse rpcResponse = (RpcResponse) e.getMessage(); ProtoCallFuture callback = requests.remove(rpcResponse.getId()); if (callback == null) { LOG.warn("Dangling rpc call"); } else { if (rpcResponse.hasErrorMessage()) { callback.setFailed(rpcResponse.getErrorMessage(), new ServiceException(getErrorMessage(rpcResponse.getErrorMessage()))); throw new RemoteException( getErrorMessage(rpcResponse.getErrorMessage())); } else { Message responseMessage; if (!rpcResponse.hasResponseMessage()) { responseMessage = null; } else { responseMessage = callback.returnType.newBuilderForType(). mergeFrom(rpcResponse.getResponseMessage()).build(); } callback.setResponse(responseMessage); } } } @Override public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception { e.getChannel().close(); for(ProtoCallFuture callback: requests.values()) { callback.setFailed(e.getCause().getMessage(), e.getCause()); } if(LOG.isDebugEnabled()) { LOG.error("" + e.getCause().getMessage(), e.getCause()); } else { LOG.error("RPC Exception:" + e.getCause().getMessage()); } } } class ProtoCallFuture implements Future<Message> { private Semaphore sem = new Semaphore(0); private Message response = null; private Message returnType; private RpcController controller; private ExecutionException ee; public ProtoCallFuture(RpcController controller, Message message) { this.controller = controller; this.returnType = message; } @Override public boolean cancel(boolean arg0) { return false; } @Override public Message get() throws InterruptedException, ExecutionException { sem.acquire(); if(ee != null) { throw ee; } return response; } @Override public Message get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { if(sem.tryAcquire(timeout, unit)) { return response; } else { throw new TimeoutException(); } } @Override public boolean isCancelled() { return false; } @Override public boolean isDone() { return sem.availablePermits() > 0; } public void setResponse(Message response) { this.response = response; sem.release(); } public void setFailed(String errorText, Throwable t) { if(controller != null) { this.controller.setFailed(errorText); } ee = new ExecutionException(errorText, t); sem.release(); } } }