/***********************************************************************************************************************
*
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
**********************************************************************************************************************/
package eu.stratosphere.nephele.rpc;
import java.io.IOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.DatagramPacket;
import java.net.InetSocketAddress;
import java.util.Iterator;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.esotericsoftware.minlog.Log;
import eu.stratosphere.util.KryoUtil;
import eu.stratosphere.util.StringUtils;
/**
* This class implements a lightweight, UDP-based RPC service.
* <p>
* This class is thread-safe.
*/
public final class RPCService {
/**
* The default number of threads handling RPC requests.
*/
private static final int DEFAULT_NUM_RPC_HANDLERS = 1;
/**
* Interval in which the background clean-up routine runs in milliseconds.
*/
static final int CLEANUP_INTERVAL = 10000;
/**
* The maximum period of time an RPC call is allowed to take in milliseconds.
*/
private static final int RPC_TIMEOUT = 60000;
/**
* The executor service managing the RPC handler threads.
*/
private final ExecutorService rpcHandlers;
/**
* The UDP port this service is bound to.
*/
private final int rpcPort;
/**
* Network thread to wait for incoming data and dispatch it among the available RPC handler threads.
*/
private final NetworkThread networkThread;
/**
* Stores whether the RPC service was requested to shut down.
*/
private final AtomicBoolean shutdownRequested = new AtomicBoolean(false);
/**
* The statistics module collects statistics on the operation of the RPC service.
*/
private final RPCStatistics statistics = new RPCStatistics();
/**
* Periodic timer to handle clean-up tasks in the background.
*/
private final Timer cleanupTimer = new Timer();
private final ConcurrentHashMap<String, RPCProtocol> callbackHandlers =
new ConcurrentHashMap<String, RPCProtocol>();
private final ConcurrentHashMap<Integer, RPCRequestMonitor> pendingRequests =
new ConcurrentHashMap<Integer, RPCRequestMonitor>();
private final ConcurrentHashMap<Integer, RPCRequest> requestsBeingProcessed =
new ConcurrentHashMap<Integer, RPCRequest>();
private final ConcurrentHashMap<Integer, CachedResponse> cachedResponses =
new ConcurrentHashMap<Integer, CachedResponse>();
public RPCService() throws IOException {
this(DEFAULT_NUM_RPC_HANDLERS);
}
public RPCService(final int numRPCHandlers) throws IOException {
this.rpcHandlers = Executors.newFixedThreadPool(numRPCHandlers);
this.rpcPort = -1;
this.networkThread = new NetworkThread(this, -1);
this.networkThread.start();
this.cleanupTimer.schedule(new CleanupTask(), CLEANUP_INTERVAL, CLEANUP_INTERVAL);
}
public RPCService(final int rpcPort, final int numRPCHandlers) throws IOException {
this.rpcHandlers = Executors.newFixedThreadPool(numRPCHandlers);
this.rpcPort = rpcPort;
this.networkThread = new NetworkThread(this, rpcPort);
this.networkThread.start();
this.cleanupTimer.schedule(new CleanupTask(), CLEANUP_INTERVAL, CLEANUP_INTERVAL);
}
@SuppressWarnings("unchecked")
public <T extends RPCProtocol> T getProxy(final InetSocketAddress remoteAddress, final Class<T> protocol) {
final Class<?>[] interfaces = new Class<?>[1];
interfaces[0] = protocol;
return (T) java.lang.reflect.Proxy.newProxyInstance(RPCService.class.getClassLoader(), interfaces,
new RPCInvocationHandler(remoteAddress, protocol.getName()));
}
public int getRPCPort() {
return this.rpcPort;
}
public void setProtocolCallbackHandler(final Class<? extends RPCProtocol> protocol,
final RPCProtocol callbackHandler) {
// Check signature of interface before adding it
checkRPCProtocol(protocol);
if (this.callbackHandlers.putIfAbsent(protocol.getName(), callbackHandler) != null)
Log.error("There is already a protocol call back handler set for protocol " + protocol.getName());
}
public void shutDown() {
if (!this.shutdownRequested.compareAndSet(false, true))
return;
// Request shutdown of network thread
try {
this.networkThread.shutdown();
} catch (final InterruptedException ie) {
Log.debug("Caught exception while waiting for network thread to shut down: ", ie);
}
this.rpcHandlers.shutdown();
try {
this.rpcHandlers.awaitTermination(5000L, TimeUnit.MILLISECONDS);
} catch (final InterruptedException ie) {
Log.debug("Caught exception while waiting for RPC handlers to finish: ", ie);
}
this.cleanupTimer.cancel();
// Finally, process the last collected data
this.statistics.processCollectedData();
}
void processIncomingRPCCleanup(final RPCCleanup rpcCleanup) {
this.cachedResponses.remove(Integer.valueOf(rpcCleanup.getMessageID()));
}
void processIncomingRPCMessage(final InetSocketAddress remoteSocketAddress, final Input input) {
final Runnable runnable = new Runnable() {
/**
* {@inheritDoc}
*/
@Override
public void run() {
final Kryo k = KryoUtil.getKryo();
k.reset();
final RPCEnvelope envelope = k.readObject(input, RPCEnvelope.class);
final RPCMessage msg = envelope.getRPCMessage();
if (msg instanceof RPCRequest)
RPCService.this.processIncomingRPCRequest(remoteSocketAddress, (RPCRequest) msg);
else if (msg instanceof RPCResponse)
RPCService.this.processIncomingRPCResponse((RPCResponse) msg);
else
RPCService.this.processIncomingRPCCleanup((RPCCleanup) msg);
}
};
this.rpcHandlers.execute(runnable);
}
/**
* Processes an incoming RPC response.
*
* @param rpcResponse
* the RPC response to be processed
*/
void processIncomingRPCResponse(final RPCResponse rpcResponse) {
final Integer messageID = Integer.valueOf(rpcResponse.getMessageID());
final RPCRequestMonitor requestMonitor = this.pendingRequests.get(messageID);
// The caller has already timed out or received an earlier response
if (requestMonitor == null)
return;
synchronized (requestMonitor) {
requestMonitor.rpcResponse = rpcResponse;
requestMonitor.notify();
}
}
/**
* Sends an RPC request to the given {@link InetSocketAddress}.
*
* @param remoteSocketAddress
* the remote address to send the request to
* @param request
* the RPC request to send
* @return the return value of the RPC call, possibly <code>null</code>
* @throws Throwable
* any exception that is thrown by the remote receiver of the RPC call
*/
Object sendRPCRequest(final InetSocketAddress remoteSocketAddress, final RPCRequest request) throws Throwable {
if (this.shutdownRequested.get())
throw new IOException("Shutdown of RPC service has already been requested");
final long start = System.currentTimeMillis();
final DatagramPacket[] packets = this.messageToPackets(remoteSocketAddress, request);
final Integer messageID = Integer.valueOf(request.getMessageID());
final RPCRequestMonitor requestMonitor = new RPCRequestMonitor();
this.pendingRequests.put(messageID, requestMonitor);
RPCResponse rpcResponse = null;
int numberOfRetries;
try {
numberOfRetries = this.networkThread.send(packets);
// Wait for the response
synchronized (requestMonitor) {
while (true) {
if (requestMonitor.rpcResponse != null) {
rpcResponse = requestMonitor.rpcResponse;
break;
}
final long sleepTime = RPC_TIMEOUT - (System.currentTimeMillis() - start);
if (sleepTime > 0L)
requestMonitor.wait(sleepTime);
else
break;
}
}
} finally {
// Request is no longer pending
this.pendingRequests.remove(messageID);
}
if (rpcResponse == null)
throw new IOException("Unable to complete RPC of method " + request.getMethodName() + " on "
+ remoteSocketAddress);
// Report the successful call to the statistics module
final String methodName = request.getMethodName();
this.statistics.reportSuccessfulTransmission(methodName, packets.length, numberOfRetries);
this.statistics.reportRTT(methodName, (int) (System.currentTimeMillis() - start));
// TODO: Send clean up message
if (rpcResponse instanceof RPCReturnValue)
return ((RPCReturnValue) rpcResponse).getRetVal();
throw ((RPCThrowable) rpcResponse).getThrowable();
}
/**
* Checks if the given class is registered with the RPC service.
*
* @param throwableType
* the class to check
* @return <code>true</code> if the given class is registered with the RPC service, <code>false</code> otherwise
*/
private boolean isThrowableRegistered(final Class<? extends Throwable> throwableType) {
final Kryo kryo = KryoUtil.getKryo();
try {
kryo.getRegistration(throwableType);
} catch (final IllegalArgumentException e) {
return false;
}
return true;
}
private DatagramPacket[] messageToPackets(final InetSocketAddress remoteSocketAddress, final RPCMessage rpcMessage) {
final MultiPacketOutputStream mpos = new MultiPacketOutputStream(RPCMessage.MAXIMUM_MSG_SIZE
+ RPCMessage.METADATA_SIZE);
final Kryo kryo = KryoUtil.getKryo();
kryo.reset();
final Output output = new Output(mpos);
kryo.writeObject(output, new RPCEnvelope(rpcMessage));
output.close();
mpos.close();
return mpos.createPackets(remoteSocketAddress);
}
private void processIncomingRPCRequest(final InetSocketAddress remoteSocketAddress, final RPCRequest rpcRequest) {
final Integer messageID = Integer.valueOf(rpcRequest.getMessageID());
if (this.requestsBeingProcessed.putIfAbsent(messageID, rpcRequest) != null) {
Log.debug("Request " + rpcRequest.getMessageID() + " is already being processed at the moment");
return;
}
final CachedResponse cachedResponse = this.cachedResponses.get(messageID);
if (cachedResponse != null) {
try {
final int numberOfRetries = this.networkThread.send(cachedResponse.packets);
this.statistics.reportSuccessfulTransmission(rpcRequest.getMethodName() + " (Response)",
cachedResponse.packets.length, numberOfRetries);
} catch (final Exception e) {
Log.error("Caught exception while trying to send RPC response: ", e);
} finally {
this.requestsBeingProcessed.remove(messageID);
}
return;
}
final RPCProtocol callbackHandler = this.callbackHandlers.get(rpcRequest.getInterfaceName());
if (callbackHandler == null) {
Log.error("Cannot find callback handler for protocol " + rpcRequest.getInterfaceName());
this.requestsBeingProcessed.remove(messageID);
return;
}
try {
final Method method = callbackHandler.getClass().getMethod(rpcRequest.getMethodName(),
rpcRequest.getParameterTypes());
RPCResponse rpcResponse = null;
try {
final Object retVal = method.invoke(callbackHandler, rpcRequest.getArgs());
rpcResponse = new RPCReturnValue(rpcRequest.getMessageID(), retVal);
} catch (final InvocationTargetException ite) {
Throwable targetException = ite.getTargetException();
// Make sure the stack trace is correctly filled
targetException.getStackTrace();
if (!this.isThrowableRegistered(targetException.getClass()))
targetException = wrapInIOException(rpcRequest, targetException);
rpcResponse = new RPCThrowable(rpcRequest.getMessageID(), targetException);
}
final DatagramPacket[] packets = this.messageToPackets(remoteSocketAddress, rpcResponse);
this.cachedResponses.put(messageID, new CachedResponse(System.currentTimeMillis(), packets));
final int numberOfRetries = this.networkThread.send(packets);
this.statistics.reportSuccessfulTransmission(rpcRequest.getMethodName() + " (Response)", packets.length,
numberOfRetries);
} catch (final Exception e) {
Log.error("Caught processing RPC request: ", e);
} finally {
this.requestsBeingProcessed.remove(messageID);
}
}
/**
* Converts the unsigned short into an integer
*
* @param val
* the unsigned short
* @return the converted integer
*/
static int decodeInteger(final short val) {
return val - Short.MIN_VALUE - 1;
}
/**
* Converts the given integer to a unsigned short.
*
* @param val
* the integer to convert
* @return the unsigned short
*/
static short encodeInteger(final int val) {
if (val < -1 || val > 65534)
throw new IllegalArgumentException("Value must be in the range -1 and 65534 but is " + val);
return (short) (val - Short.MIN_VALUE + 1);
}
/**
* Checks the signature of the methods contained in the given protocol.
*
* @param protocol
* the protocol to be checked
*/
private static final void checkRPCProtocol(final Class<? extends RPCProtocol> protocol) {
if (!protocol.isInterface())
throw new IllegalArgumentException("Provided protocol " + protocol + " is not an interface");
try {
final Method[] methods = protocol.getMethods();
for (int i = 0; i < methods.length; ++i) {
final Method method = methods[i];
final Class<?>[] exceptionTypes = method.getExceptionTypes();
boolean ioExceptionFound = false;
boolean interruptedExceptionFound = false;
for (int j = 0; j < exceptionTypes.length; ++j)
if (IOException.class.equals(exceptionTypes[j]))
ioExceptionFound = true;
else if (InterruptedException.class.equals(exceptionTypes[j]))
interruptedExceptionFound = true;
if (!ioExceptionFound)
throw new IllegalArgumentException("Method " + method.getName()
+ " of protocol " + protocol.getName() + " must be declared to throw an IOException");
if (!interruptedExceptionFound)
throw new IllegalArgumentException("Method " + method.getName()
+ " of protocol " + protocol.getName() + " must be declared to throw an InterruptedException");
}
} catch (final SecurityException se) {
if (Log.DEBUG)
Log.debug(StringUtils.stringifyException(se));
}
}
/**
* Transforms the given {@link Throwable} into a string and wraps it into an {@link IOException}.
*
* @param request
* the RPC request which caused the {@link Throwable} to be wrapped
* @param throwable
* the {@link Throwable} to be wrapped
* @return the {@link} IOException created from the {@link Throwable}
*/
private static IOException wrapInIOException(final RPCRequest request, final Throwable throwable) {
final StringBuilder sb = new StringBuilder("The remote procedure call of method ");
sb.append(request.getInterfaceName());
sb.append('.');
sb.append(request.getMethodName());
sb.append(" caused an unregistered exception: ");
sb.append(StringUtils.stringifyException(throwable));
return new IOException(sb.toString());
}
private static final class CachedResponse {
private final long creationTime;
private final DatagramPacket[] packets;
private CachedResponse(final long creationTime, final DatagramPacket[] packets) {
this.creationTime = creationTime;
this.packets = packets;
}
}
private final class CleanupTask extends TimerTask {
/**
* {@inheritDoc}
*/
@Override
public void run() {
// Process the collected data
RPCService.this.statistics.processCollectedData();
final long now = System.currentTimeMillis();
final Iterator<Map.Entry<Integer, CachedResponse>> it =
RPCService.this.cachedResponses.entrySet().iterator();
while (it.hasNext()) {
final Map.Entry<Integer, CachedResponse> entry = it.next();
final CachedResponse cachedResponse = entry.getValue();
if (cachedResponse.creationTime + CLEANUP_INTERVAL < now)
it.remove();
}
RPCService.this.networkThread.cleanUpStaleState();
}
}
private final class RPCInvocationHandler implements InvocationHandler {
private final InetSocketAddress remoteSocketAddress;
private final String interfaceName;
private RPCInvocationHandler(final InetSocketAddress remoteSocketAddress, final String interfaceName) {
this.remoteSocketAddress = remoteSocketAddress;
this.interfaceName = interfaceName;
}
/**
* {@inheritDoc}
*/
@Override
public Object invoke(final Object proxy, final Method method, final Object[] args) throws Throwable {
final int messageID = (int) (Integer.MIN_VALUE + Math.random() * Integer.MAX_VALUE * 2.0);
final RPCRequest rpcRequest = new RPCRequest(messageID, this.interfaceName, method, args);
return RPCService.this.sendRPCRequest(this.remoteSocketAddress, rpcRequest);
}
}
private static final class RPCRequestMonitor {
private RPCResponse rpcResponse = null;
}
}