/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.runtime.rpc; import org.apache.flink.api.common.time.Time; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.runtime.concurrent.Future; import org.apache.flink.util.ReflectionUtil; import org.apache.flink.util.TestLogger; import org.junit.Test; import org.reflections.Reflections; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** * Test which ensures that all classes of subtype {@link RpcEndpoint} implement * the methods specified in the generic gateway type argument. * * {@code * RpcEndpoint<GatewayTypeParameter extends RpcGateway> * } * * Note, that the class hierarchy can also be nested. In this case the type argument * always has to be the first argument, e.g. {@code * * // RpcClass needs to implement RpcGatewayClass' methods * RpcClass extends RpcEndpoint<RpcGatewayClass> * * // RpcClass2 or its subclass needs to implement RpcGatewayClass' methods * RpcClass<GatewayTypeParameter extends RpcGateway,...> extends RpcEndpoint<GatewayTypeParameter> * RpcClass2 extends RpcClass<RpcGatewayClass,...> * * // needless to say, this can even be nested further * ... * } * */ public class RpcCompletenessTest extends TestLogger { private static Logger LOG = LoggerFactory.getLogger(RpcCompletenessTest.class); private static final Class<?> futureClass = Future.class; private static final Class<?> timeoutClass = Time.class; @Test @SuppressWarnings({"rawtypes", "unchecked"}) public void testRpcCompleteness() { Reflections reflections = new Reflections("org.apache.flink"); Set<Class<? extends RpcEndpoint>> classes = reflections.getSubTypesOf(RpcEndpoint.class); Class<? extends RpcEndpoint> c; mainloop: for (Class<? extends RpcEndpoint> rpcEndpoint : classes) { c = rpcEndpoint; LOG.debug("-------------"); LOG.debug("c: {}", c); // skip abstract classes if (Modifier.isAbstract(c.getModifiers())) { LOG.debug("Skipping abstract class"); continue; } // check for type parameter bound to RpcGateway // skip if one is found because a subclass will provide the concrete argument TypeVariable<? extends Class<? extends RpcEndpoint>>[] typeParameters = c.getTypeParameters(); LOG.debug("Checking {} parameters.", typeParameters.length); for (int i = 0; i < typeParameters.length; i++) { for (Type bound : typeParameters[i].getBounds()) { LOG.debug("checking bound {} of type parameter {}", bound, typeParameters[i]); if (bound.toString().equals("interface " + RpcGateway.class.getName())) { if (i > 0) { fail("Type parameter for RpcGateway should come first in " + c); } LOG.debug("Skipping class with type parameter bound to RpcGateway."); // Type parameter is bound to RpcGateway which a subclass will provide continue mainloop; } } } // check if this class or any super class contains the RpcGateway argument Class<?> rpcGatewayType; do { LOG.debug("checking type argument of class: {}", c); rpcGatewayType = ReflectionUtil.getTemplateType1(c); LOG.debug("type argument is: {}", rpcGatewayType); c = (Class<? extends RpcEndpoint>) c.getSuperclass(); } while (!RpcGateway.class.isAssignableFrom(rpcGatewayType)); LOG.debug("Checking RRC completeness of endpoint '{}' with gateway '{}'", rpcEndpoint.getSimpleName(), rpcGatewayType.getSimpleName()); checkCompleteness(rpcEndpoint, (Class<? extends RpcGateway>) rpcGatewayType); } } @SuppressWarnings("rawtypes") private void checkCompleteness(Class<? extends RpcEndpoint> rpcEndpoint, Class<? extends RpcGateway> rpcGateway) { List<Method> rpcMethodsFromGateway = getRpcMethodsFromGateway(rpcGateway); Method[] gatewayMethods = rpcMethodsFromGateway.toArray(new Method[rpcMethodsFromGateway.size()]); Method[] serverMethods = rpcEndpoint.getMethods(); Map<String, Set<Method>> rpcMethods = new HashMap<>(); Set<Method> unmatchedRpcMethods = new HashSet<>(); for (Method serverMethod : serverMethods) { if (serverMethod.isAnnotationPresent(RpcMethod.class)) { if (rpcMethods.containsKey(serverMethod.getName())) { Set<Method> methods = rpcMethods.get(serverMethod.getName()); methods.add(serverMethod); rpcMethods.put(serverMethod.getName(), methods); } else { Set<Method> methods = new HashSet<>(); methods.add(serverMethod); rpcMethods.put(serverMethod.getName(), methods); } unmatchedRpcMethods.add(serverMethod); } } for (Method gatewayMethod : gatewayMethods) { assertTrue( "The rpc endpoint " + rpcEndpoint.getName() + " does not contain a RpcMethod " + "annotated method with the same name and signature " + generateEndpointMethodSignature(gatewayMethod) + ".", rpcMethods.containsKey(gatewayMethod.getName())); checkGatewayMethod(gatewayMethod); if (!matchGatewayMethodWithEndpoint(gatewayMethod, rpcMethods.get(gatewayMethod.getName()), unmatchedRpcMethods)) { fail("Could not find a RpcMethod annotated method in rpc endpoint " + rpcEndpoint.getName() + " matching the rpc gateway method " + generateEndpointMethodSignature(gatewayMethod) + " defined in the rpc gateway " + rpcGateway.getName() + "."); } } if (!unmatchedRpcMethods.isEmpty()) { StringBuilder builder = new StringBuilder(); for (Method unmatchedRpcMethod : unmatchedRpcMethods) { builder.append(unmatchedRpcMethod).append("\n"); } fail("The rpc endpoint " + rpcEndpoint.getName() + " contains rpc methods which " + "are not matched to gateway methods of " + rpcGateway.getName() + ":\n" + builder.toString()); } } /** * Checks whether the gateway method fulfills the gateway method requirements. * <ul> * <li>It checks whether the return type is void or a {@link Future} wrapping the actual result. </li> * <li>It checks that the method's parameter list contains at most one parameter annotated with {@link RpcTimeout}.</li> * </ul> * * @param gatewayMethod Gateway method to check */ private void checkGatewayMethod(Method gatewayMethod) { if (!gatewayMethod.getReturnType().equals(Void.TYPE)) { assertTrue( "The return type of method " + gatewayMethod.getName() + " in the rpc gateway " + gatewayMethod.getDeclaringClass().getName() + " is non void and not a " + "future. Non-void return types have to be returned as a future.", gatewayMethod.getReturnType().equals(futureClass)); } Annotation[][] parameterAnnotations = gatewayMethod.getParameterAnnotations(); Class<?>[] parameterTypes = gatewayMethod.getParameterTypes(); int rpcTimeoutParameters = 0; for (int i = 0; i < parameterAnnotations.length; i++) { if (RpcCompletenessTest.isRpcTimeout(parameterAnnotations[i])) { assertTrue( "The rpc timeout has to be of type " + timeoutClass.getName() + ".", parameterTypes[i].equals(timeoutClass)); rpcTimeoutParameters++; } } assertTrue("The gateway method " + gatewayMethod + " must have at most one RpcTimeout " + "annotated parameter.", rpcTimeoutParameters <= 1); } /** * Checks whether we find a matching overloaded version for the gateway method among the methods * with the same name in the rpc endpoint. * * @param gatewayMethod Gateway method * @param endpointMethods Set of rpc methods on the rpc endpoint with the same name as the gateway * method * @param unmatchedRpcMethods Set of unmatched rpc methods on the endpoint side (so far) */ private boolean matchGatewayMethodWithEndpoint(Method gatewayMethod, Set<Method> endpointMethods, Set<Method> unmatchedRpcMethods) { for (Method endpointMethod : endpointMethods) { if (checkMethod(gatewayMethod, endpointMethod)) { unmatchedRpcMethods.remove(endpointMethod); return true; } } return false; } private boolean checkMethod(Method gatewayMethod, Method endpointMethod) { Class<?>[] gatewayParameterTypes = gatewayMethod.getParameterTypes(); Annotation[][] gatewayParameterAnnotations = gatewayMethod.getParameterAnnotations(); Class<?>[] endpointParameterTypes = endpointMethod.getParameterTypes(); List<Class<?>> filteredGatewayParameterTypes = new ArrayList<>(); assertEquals(gatewayParameterTypes.length, gatewayParameterAnnotations.length); // filter out the RpcTimeout parameters for (int i = 0; i < gatewayParameterTypes.length; i++) { if (!RpcCompletenessTest.isRpcTimeout(gatewayParameterAnnotations[i])) { filteredGatewayParameterTypes.add(gatewayParameterTypes[i]); } } if (filteredGatewayParameterTypes.size() != endpointParameterTypes.length) { return false; } else { // check the parameter types for (int i = 0; i < filteredGatewayParameterTypes.size(); i++) { if (!checkType(filteredGatewayParameterTypes.get(i), endpointParameterTypes[i])) { return false; } } // check the return types if (endpointMethod.getReturnType() == void.class) { if (gatewayMethod.getReturnType() != void.class) { return false; } } else { // has return value. The gateway method should be wrapped in a future Class<?> futureClass = gatewayMethod.getReturnType(); // sanity check that the return type of a gateway method must be void or a future if (!futureClass.equals(RpcCompletenessTest.futureClass)) { return false; } else { ReflectionUtil.FullTypeInfo fullValueTypeInfo = ReflectionUtil.getFullTemplateType(gatewayMethod.getGenericReturnType(), 0); if (endpointMethod.getReturnType().equals(futureClass)) { ReflectionUtil.FullTypeInfo fullRpcEndpointValueTypeInfo = ReflectionUtil.getFullTemplateType(endpointMethod.getGenericReturnType(), 0); // check if we have the same future value types if (fullValueTypeInfo != null && fullRpcEndpointValueTypeInfo != null) { Iterator<Class<?>> valueClasses = fullValueTypeInfo.getClazzIterator(); Iterator<Class<?>> rpcClasses = fullRpcEndpointValueTypeInfo.getClazzIterator(); while (valueClasses.hasNext() && rpcClasses.hasNext()) { if (!checkType(valueClasses.next(), rpcClasses.next())) { return false; } } // both should be empty return !valueClasses.hasNext() && !rpcClasses.hasNext(); } } else { if (fullValueTypeInfo != null && !checkType(fullValueTypeInfo.getClazz(), endpointMethod.getReturnType())) { return false; } } } } return gatewayMethod.getName().equals(endpointMethod.getName()); } } private boolean checkType(Class<?> firstType, Class<?> secondType) { Class<?> firstResolvedType; Class<?> secondResolvedType; if (firstType.isPrimitive()) { firstResolvedType = RpcCompletenessTest.resolvePrimitiveType(firstType); } else { firstResolvedType = firstType; } if (secondType.isPrimitive()) { secondResolvedType = RpcCompletenessTest.resolvePrimitiveType(secondType); } else { secondResolvedType = secondType; } return firstResolvedType.equals(secondResolvedType); } /** * Generates from a gateway rpc method signature the corresponding rpc endpoint signature. * * For example the {@link RpcTimeout} annotation adds an additional parameter to the gateway * signature which is not relevant on the server side. * * @param method Method to generate the signature string for * @return String of the respective server side rpc method signature */ private String generateEndpointMethodSignature(Method method) { StringBuilder builder = new StringBuilder(); if (method.getReturnType().equals(Void.TYPE)) { builder.append("void").append(" "); } else if (method.getReturnType().equals(futureClass)) { ReflectionUtil.FullTypeInfo fullTypeInfo = ReflectionUtil.getFullTemplateType(method.getGenericReturnType(), 0); builder .append(futureClass.getSimpleName()) .append("<") .append(fullTypeInfo != null ? fullTypeInfo.toString() : "") .append(">"); if (fullTypeInfo != null) { builder.append("/").append(fullTypeInfo); } builder.append(" "); } else { return "Invalid rpc method signature."; } builder.append(method.getName()).append("("); Class<?>[] parameterTypes = method.getParameterTypes(); Annotation[][] parameterAnnotations = method.getParameterAnnotations(); assertEquals(parameterTypes.length, parameterAnnotations.length); for (int i = 0; i < parameterTypes.length; i++) { // filter out the RpcTimeout parameters if (!RpcCompletenessTest.isRpcTimeout(parameterAnnotations[i])) { builder.append(parameterTypes[i].getName()); if (i < parameterTypes.length -1) { builder.append(", "); } } } builder.append(")"); return builder.toString(); } private static boolean isRpcTimeout(Annotation[] annotations) { for (Annotation annotation : annotations) { if (annotation.annotationType().equals(RpcTimeout.class)) { return true; } } return false; } /** * Returns the boxed type for a primitive type. * * @param primitveType Primitive type to resolve * @return Boxed type for the given primitive type */ private static Class<?> resolvePrimitiveType(Class<?> primitveType) { assert primitveType.isPrimitive(); TypeInformation<?> typeInformation = BasicTypeInfo.getInfoFor(primitveType); if (typeInformation != null) { return typeInformation.getTypeClass(); } else { throw new RuntimeException("Could not retrive basic type information for primitive type " + primitveType + '.'); } } /** * Extract all rpc methods defined by the gateway interface * * @param interfaceClass the given rpc gateway interface * @return all methods defined by the given interface */ private List<Method> getRpcMethodsFromGateway(Class<? extends RpcGateway> interfaceClass) { if(!interfaceClass.isInterface()) { fail(interfaceClass.getName() + " is not a interface"); } ArrayList<Method> allMethods = new ArrayList<>(); // Methods defined in RpcGateway are native method if(interfaceClass.equals(RpcGateway.class)) { return allMethods; } // Get all methods declared in current interface Collections.addAll(allMethods, interfaceClass.getDeclaredMethods()); // Get all method inherited from super interface for (Class<?> superClass : interfaceClass.getInterfaces()) { @SuppressWarnings("unchecked") Class<? extends RpcGateway> gatewayClass = (Class<? extends RpcGateway>) superClass; allMethods.addAll(getRpcMethodsFromGateway(gatewayClass)); } return allMethods; } }