/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.rpc;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.concurrent.CompletableFuture;
import org.apache.flink.runtime.concurrent.Future;
import org.apache.flink.runtime.concurrent.ScheduledExecutor;
import org.apache.flink.runtime.concurrent.ScheduledExecutorServiceAdapter;
import org.apache.flink.runtime.concurrent.impl.FlinkCompletableFuture;
import org.apache.flink.runtime.util.DirectExecutorService;
import org.apache.flink.util.Preconditions;

import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.BitSet;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Delayed;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.apache.flink.util.Preconditions.checkNotNull;
/**
 * An RPC Service implementation for testing. This RPC service directly executes all asynchronous
 * calls one by one in the calling thread.
 *
 * <p>Servers started via {@link #startServer(RpcEndpoint)} are exposed as dynamic proxies which are
 * registered under a random UUID address and dispatch every gateway call synchronously to the
 * endpoint. Exceptions thrown by the endpoint surface as the endpoint's own exception rather than
 * as a reflective {@link InvocationTargetException} wrapper. Additional gateways can be registered
 * manually via {@link #registerGateway(String, RpcGateway)} and are looked up via
 * {@link #connect(String, Class)}.
 */
public class TestingSerialRpcService implements RpcService {

	/** Executor handed out by {@link #getExecutor()}. */
	private final DirectExecutorService executorService;

	/** Single-threaded scheduler backing {@link #getScheduledExecutor()}. */
	private final ScheduledExecutorService scheduledExecutorService;

	/** Gateways reachable via {@link #connect(String, Class)}, keyed by their address. */
	private final ConcurrentHashMap<String, RpcGateway> registeredConnections;

	/** Completed at the end of {@link #stopService()}. */
	private final CompletableFuture<Void> terminationFuture;

	/** {@link ScheduledExecutor} view of {@link #scheduledExecutorService}. */
	private final ScheduledExecutor scheduledExecutorServiceAdapter;

	public TestingSerialRpcService() {
		this.executorService = new DirectExecutorService();
		this.scheduledExecutorService = new ScheduledThreadPoolExecutor(1);
		this.registeredConnections = new ConcurrentHashMap<>(16);
		this.terminationFuture = new FlinkCompletableFuture<>();
		this.scheduledExecutorServiceAdapter = new ScheduledExecutorServiceAdapter(scheduledExecutorService);
	}

	// ------------------------------------------------------------------------
	// execution
	// ------------------------------------------------------------------------

	/**
	 * Blocks the calling thread for the given delay and then runs the runnable in the calling
	 * thread.
	 *
	 * @return an already completed future
	 * @throws RuntimeException wrapping any failure of the sleep or of the runnable
	 */
	@Override
	public ScheduledFuture<?> scheduleRunnable(final Runnable runnable, final long delay, final TimeUnit unit) {
		sleepThenRun(runnable, delay, unit);
		return new DoneScheduledFuture<Void>(null);
	}

	/** Runs the runnable immediately in the calling thread. */
	@Override
	public void execute(Runnable runnable) {
		runnable.run();
	}

	/**
	 * Calls the callable immediately in the calling thread.
	 *
	 * @return a future completed with the callable's result, or completed exceptionally with the
	 *         exception the callable threw
	 */
	@Override
	public <T> Future<T> execute(Callable<T> callable) {
		try {
			T result = callable.call();
			return FlinkCompletableFuture.completed(result);
		} catch (Exception e) {
			return FlinkCompletableFuture.completedExceptionally(e);
		}
	}

	@Override
	public Executor getExecutor() {
		return executorService;
	}

	/** Returns a {@link ScheduledExecutor} backed by a single-threaded scheduled thread pool. */
	public ScheduledExecutor getScheduledExecutor() {
		return scheduledExecutorServiceAdapter;
	}

	/**
	 * Sleeps for the given delay in the calling thread and then runs the runnable there.
	 *
	 * <p>Any failure is wrapped in a {@link RuntimeException}. If the sleep is interrupted, the
	 * thread's interrupt flag is restored before rethrowing so that callers can still observe it.
	 */
	private static void sleepThenRun(final Runnable runnable, final long delay, final TimeUnit unit) {
		try {
			unit.sleep(delay);
			runnable.run();
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			throw new RuntimeException(e);
		} catch (Throwable e) {
			throw new RuntimeException(e);
		}
	}

	// ------------------------------------------------------------------------
	// lifecycle
	// ------------------------------------------------------------------------

	/**
	 * Shuts down both executors, clears all registered gateways and completes the termination
	 * future.
	 *
	 * <p>The scheduled executor gets up to one second to terminate. Tasks still queued after that
	 * (or if the wait is interrupted) are drained via {@link ScheduledExecutorService#shutdownNow()}
	 * and run in the calling thread.
	 */
	@Override
	public void stopService() {
		executorService.shutdown();
		scheduledExecutorService.shutdown();

		boolean terminated = false;

		try {
			terminated = scheduledExecutorService.awaitTermination(1, TimeUnit.SECONDS);
		} catch (InterruptedException e) {
			// preserve the interrupt status for the caller
			Thread.currentThread().interrupt();
		}

		if (!terminated) {
			List<Runnable> runnables = scheduledExecutorService.shutdownNow();

			for (Runnable runnable : runnables) {
				runnable.run();
			}
		}

		registeredConnections.clear();
		terminationFuture.complete(null);
	}

	@Override
	public Future<Void> getTerminationFuture() {
		return terminationFuture;
	}

	/** Unregisters the given gateway so that it can no longer be found via {@link #connect}. */
	@Override
	public void stopServer(RpcGateway selfGateway) {
		registeredConnections.remove(selfGateway.getAddress());
	}

	/**
	 * Creates a proxy gateway for the endpoint under a fresh random address and registers it.
	 *
	 * <p>The proxy implements the endpoint's self gateway type as well as
	 * {@link MainThreadExecutable}, {@link StartStoppable} and {@link RpcGateway}; all calls are
	 * executed synchronously in the calling thread.
	 */
	@Override
	public <C extends RpcGateway, S extends RpcEndpoint<C>> C startServer(S rpcEndpoint) {
		final String address = UUID.randomUUID().toString();

		InvocationHandler invocationHandler = new TestingSerialRpcService.TestingSerialInvocationHandler<>(address, rpcEndpoint);
		ClassLoader classLoader = getClass().getClassLoader();

		@SuppressWarnings("unchecked")
		C self = (C) Proxy.newProxyInstance(
			classLoader,
			new Class<?>[]{
				rpcEndpoint.getSelfGatewayType(),
				MainThreadExecutable.class,
				StartStoppable.class,
				RpcGateway.class
			},
			invocationHandler);

		// register self so that the endpoint is reachable via connect()
		registeredConnections.putIfAbsent(self.getAddress(), self);

		return self;
	}

	@Override
	public String getAddress() {
		return "";
	}

	/**
	 * Looks up a gateway registered under the given address.
	 *
	 * @return a future completed with the gateway, or completed exceptionally if nothing is
	 *         registered under the address or the registered gateway is not of type {@code clazz}
	 */
	@Override
	public <C extends RpcGateway> Future<C> connect(String address, Class<C> clazz) {
		RpcGateway gateway = registeredConnections.get(address);

		if (gateway != null) {
			if (clazz.isAssignableFrom(gateway.getClass())) {
				@SuppressWarnings("unchecked")
				C typedGateway = (C) gateway;
				return FlinkCompletableFuture.completed(typedGateway);
			} else {
				return FlinkCompletableFuture.completedExceptionally(
					new Exception("Gateway registered under " + address + " is not of type " + clazz));
			}
		} else {
			return FlinkCompletableFuture.completedExceptionally(new Exception("No gateway registered under that name"));
		}
	}

	// ------------------------------------------------------------------------
	// connections
	// ------------------------------------------------------------------------

	/**
	 * Registers a gateway under the given address.
	 *
	 * @throws IllegalStateException if a gateway is already registered under the address
	 */
	public void registerGateway(String address, RpcGateway gateway) {
		checkNotNull(address);
		checkNotNull(gateway);

		if (registeredConnections.putIfAbsent(address, gateway) != null) {
			throw new IllegalStateException("a gateway is already registered under " + address);
		}
	}

	/** Removes all registered gateways. */
	public void clearGateways() {
		registeredConnections.clear();
	}

	/**
	 * Invocation handler backing the proxies created by {@link #startServer(RpcEndpoint)}. Calls
	 * to the infrastructure interfaces are served by the handler itself; all other calls are
	 * dispatched reflectively and synchronously to the rpc endpoint.
	 */
	private static final class TestingSerialInvocationHandler<C extends RpcGateway, T extends RpcEndpoint<C>> implements InvocationHandler, RpcGateway, MainThreadExecutable, StartStoppable {

		private final T rpcEndpoint;

		/** default timeout for asks */
		private final Time timeout;

		private final String address;

		private TestingSerialInvocationHandler(String address, T rpcEndpoint) {
			this(address, rpcEndpoint, Time.seconds(10));
		}

		private TestingSerialInvocationHandler(String address, T rpcEndpoint, Time timeout) {
			this.rpcEndpoint = rpcEndpoint;
			this.timeout = timeout;
			this.address = address;
		}

		/**
		 * Dispatches a proxy call.
		 *
		 * <p>Gateway methods returning {@link Future} never throw; failures are reported through
		 * an exceptionally completed future. Other gateway methods rethrow the endpoint's
		 * exception directly.
		 */
		@Override
		public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
			final Class<?> declaringClass = method.getDeclaringClass();

			if (declaringClass.equals(MainThreadExecutable.class) ||
				declaringClass.equals(Object.class) ||
				declaringClass.equals(StartStoppable.class) ||
				declaringClass.equals(RpcGateway.class)) {
				// infrastructure methods are implemented by this handler itself
				return invokeUnwrapped(method, this, args);
			} else {
				final String methodName = method.getName();
				final Class<?>[] parameterTypes = method.getParameterTypes();
				final Annotation[][] parameterAnnotations = method.getParameterAnnotations();
				final Time futureTimeout = extractRpcTimeout(parameterAnnotations, args, timeout);

				// the @RpcTimeout parameter is not part of the endpoint's method signature
				final Tuple2<Class<?>[], Object[]> filteredArguments = filterArguments(
					parameterTypes,
					parameterAnnotations,
					args);

				final Class<?> returnType = method.getReturnType();

				if (returnType.equals(Future.class)) {
					try {
						Object result = handleRpcInvocationSync(methodName, filteredArguments.f0, filteredArguments.f1, futureTimeout);
						return FlinkCompletableFuture.completed(result);
					} catch (Throwable e) {
						return FlinkCompletableFuture.completedExceptionally(e);
					}
				} else {
					return handleRpcInvocationSync(methodName, filteredArguments.f0, filteredArguments.f1, futureTimeout);
				}
			}
		}

		/**
		 * Handle rpc invocations by looking up the rpc method on the rpc endpoint and calling this
		 * method with the provided method arguments. If the method has a return value, it is returned
		 * to the sender of the call. If the endpoint returns a {@link Future}, this method blocks for
		 * at most {@code futureTimeout} and returns the future's value.
		 *
		 * @throws Throwable the exception thrown by the endpoint method itself (unwrapped from the
		 *                   reflective {@link InvocationTargetException}), a lookup failure, or a
		 *                   failure while waiting on the returned future
		 */
		private Object handleRpcInvocationSync(final String methodName,
				final Class<?>[] parameterTypes,
				final Object[] args,
				final Time futureTimeout) throws Throwable {
			final Method rpcMethod = lookupRpcMethod(methodName, parameterTypes);

			Object result = invokeUnwrapped(rpcMethod, rpcEndpoint, args);

			if (result instanceof Future) {
				Future<?> future = (Future<?>) result;
				return future.get(futureTimeout.getSize(), futureTimeout.getUnit());
			} else {
				return result;
			}
		}

		/**
		 * Reflectively invokes {@code method} on {@code target} and rethrows the exception raised
		 * by the invoked method itself instead of the {@link InvocationTargetException} wrapper.
		 */
		private static Object invokeUnwrapped(Method method, Object target, Object[] args) throws Throwable {
			try {
				return method.invoke(target, args);
			} catch (InvocationTargetException e) {
				Throwable cause = e.getTargetException();
				throw cause != null ? cause : e;
			}
		}

		/** Runs the runnable immediately in the calling thread. */
		@Override
		public void runAsync(Runnable runnable) {
			runnable.run();
		}

		/**
		 * Calls the callable immediately in the calling thread; {@code callTimeout} is ignored.
		 */
		@Override
		public <V> Future<V> callAsync(Callable<V> callable, Time callTimeout) {
			try {
				return FlinkCompletableFuture.completed(callable.call());
			} catch (Throwable e) {
				return FlinkCompletableFuture.completedExceptionally(e);
			}
		}

		/**
		 * Blocks the calling thread for {@code delay} milliseconds, then runs the runnable in the
		 * calling thread.
		 *
		 * @throws RuntimeException wrapping any failure of the sleep or of the runnable
		 */
		@Override
		public void scheduleRunAsync(final Runnable runnable, final long delay) {
			sleepThenRun(runnable, delay, TimeUnit.MILLISECONDS);
		}

		@Override
		public String getAddress() {
			return address;
		}

		// this is not a real hostname but the address above is also not a real akka RPC address
		// and we keep it that way until actually needed by a test case
		@Override
		public String getHostname() {
			return address;
		}

		@Override
		public void start() {
			// do nothing
		}

		@Override
		public void stop() {
			// do nothing
		}

		/**
		 * Look up the rpc method on the given {@link RpcEndpoint} instance.
		 *
		 * @param methodName Name of the method
		 * @param parameterTypes Parameter types of the method
		 * @return Method of the rpc endpoint
		 * @throws NoSuchMethodException Thrown if the method with the given name and parameter types
		 *                               cannot be found at the rpc endpoint
		 */
		private Method lookupRpcMethod(final String methodName,
				final Class<?>[] parameterTypes) throws NoSuchMethodException {
			return rpcEndpoint.getClass().getMethod(methodName, parameterTypes);
		}

		// ------------------------------------------------------------------------
		// Helper methods
		// ------------------------------------------------------------------------

		/**
		 * Extracts the {@link RpcTimeout} annotated rpc timeout value from the list of given method
		 * arguments. If no {@link RpcTimeout} annotated parameter could be found, then the default
		 * timeout is returned.
		 *
		 * @param parameterAnnotations Parameter annotations
		 * @param args Array of arguments
		 * @param defaultTimeout Default timeout to return if no {@link RpcTimeout} annotated parameter
		 *                       has been found
		 * @return Timeout extracted from the array of arguments or the default timeout
		 * @throws RuntimeException if the annotated argument is not a {@link Time} (including null)
		 */
		private static Time extractRpcTimeout(Annotation[][] parameterAnnotations, Object[] args,
				Time defaultTimeout) {
			if (args != null) {
				Preconditions.checkArgument(parameterAnnotations.length == args.length);

				for (int i = 0; i < parameterAnnotations.length; i++) {
					if (isRpcTimeout(parameterAnnotations[i])) {
						if (args[i] instanceof Time) {
							return (Time) args[i];
						} else {
							// a null argument has no class; report it instead of failing with an NPE
							final String actualType = args[i] == null ? "null" : args[i].getClass().getName();
							throw new RuntimeException("The rpc timeout parameter must be of type " +
								Time.class.getName() + ". The type " + actualType +
								" is not supported.");
						}
					}
				}
			}

			return defaultTimeout;
		}

		/**
		 * Removes all {@link RpcTimeout} annotated parameters from the parameter type and argument
		 * list.
		 *
		 * @param parameterTypes Array of parameter types
		 * @param parameterAnnotations Array of parameter annotations
		 * @param args Array of arguments (may be null for methods without parameters)
		 * @return Tuple of filtered parameter types and arguments which no longer contain the
		 *         {@link RpcTimeout} annotated parameter types and arguments
		 */
		private static Tuple2<Class<?>[], Object[]> filterArguments(
				Class<?>[] parameterTypes,
				Annotation[][] parameterAnnotations,
				Object[] args) {

			Class<?>[] filteredParameterTypes;
			Object[] filteredArgs;

			if (args == null) {
				filteredParameterTypes = parameterTypes;
				filteredArgs = null;
			} else {
				Preconditions.checkArgument(parameterTypes.length == parameterAnnotations.length);
				Preconditions.checkArgument(parameterAnnotations.length == args.length);

				BitSet isRpcTimeoutParameter = new BitSet(parameterTypes.length);
				int numberRpcParameters = parameterTypes.length;

				for (int i = 0; i < parameterTypes.length; i++) {
					if (isRpcTimeout(parameterAnnotations[i])) {
						isRpcTimeoutParameter.set(i);
						numberRpcParameters--;
					}
				}

				if (numberRpcParameters == parameterTypes.length) {
					// nothing to filter; reuse the original arrays
					filteredParameterTypes = parameterTypes;
					filteredArgs = args;
				} else {
					filteredParameterTypes = new Class<?>[numberRpcParameters];
					filteredArgs = new Object[numberRpcParameters];
					int counter = 0;

					for (int i = 0; i < parameterTypes.length; i++) {
						if (!isRpcTimeoutParameter.get(i)) {
							filteredParameterTypes[counter] = parameterTypes[i];
							filteredArgs[counter] = args[i];
							counter++;
						}
					}
				}
			}

			return Tuple2.of(filteredParameterTypes, filteredArgs);
		}

		/**
		 * Checks whether any of the annotations is of type {@link RpcTimeout}.
		 *
		 * @param annotations Array of annotations
		 * @return True if {@link RpcTimeout} was found; otherwise false
		 */
		private static boolean isRpcTimeout(Annotation[] annotations) {
			for (Annotation annotation : annotations) {
				if (annotation.annotationType().equals(RpcTimeout.class)) {
					return true;
				}
			}

			return false;
		}
	}

	/**
	 * A {@link ScheduledFuture} that is already done with a fixed value, has no remaining delay
	 * and cannot be cancelled.
	 */
	private static class DoneScheduledFuture<V> implements ScheduledFuture<V> {

		private final V value;

		private DoneScheduledFuture(V value) {
			this.value = value;
		}

		@Override
		public long getDelay(TimeUnit unit) {
			return 0L;
		}

		/**
		 * Orders by remaining delay as required by {@link Delayed}. Two done futures compare as
		 * equal; a future with a positive remaining delay compares as later.
		 */
		@Override
		public int compareTo(Delayed o) {
			return Long.compare(getDelay(TimeUnit.NANOSECONDS), o.getDelay(TimeUnit.NANOSECONDS));
		}

		@Override
		public boolean cancel(boolean mayInterruptIfRunning) {
			return false;
		}

		@Override
		public boolean isCancelled() {
			return false;
		}

		@Override
		public boolean isDone() {
			return true;
		}

		@Override
		public V get() throws InterruptedException, ExecutionException {
			return value;
		}

		@Override
		public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
			return value;
		}
	}
}