/* * Copyright 2016 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.linecorp.armeria.internal.thrift; import static java.util.Objects.requireNonNull; import java.lang.reflect.Method; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Map.Entry; import org.apache.thrift.AsyncProcessFunction; import org.apache.thrift.ProcessFunction; import org.apache.thrift.TApplicationException; import org.apache.thrift.TBase; import org.apache.thrift.TException; import org.apache.thrift.TFieldIdEnum; import org.apache.thrift.meta_data.FieldMetaData; import org.apache.thrift.protocol.TMessageType; import com.google.common.collect.ImmutableMap; /** * Provides the metadata of a Thrift service function. */ public final class ThriftFunction { private enum Type { SYNC, ASYNC } private final Object func; private final Type type; private final Class<?> serviceType; private final String name; private final TBase<?, ?> result; private final TFieldIdEnum[] argFields; private final TFieldIdEnum successField; private final Map<Class<Throwable>, TFieldIdEnum> exceptionFields; private final Class<?>[] declaredExceptions; ThriftFunction(Class<?> serviceType, ProcessFunction<?, ?> func) throws Exception { this(serviceType, func.getMethodName(), func, Type.SYNC, getArgFields(func), getResult(func), getDeclaredExceptions(func)); } ThriftFunction(Class<?> serviceType, AsyncProcessFunction<?, ?, ?> func) throws Exception { this(serviceType, func.getMethodName(), func, Type.ASYNC, getArgFields(func), getResult(func), getDeclaredExceptions(func)); } private ThriftFunction( Class<?> serviceType, String name, Object func, Type type, TFieldIdEnum[] argFields, TBase<?, ?> result, Class<?>[] declaredExceptions) throws Exception { this.func = func; this.type = type; this.serviceType = serviceType; this.name = name; this.argFields = argFields; this.result = result; this.declaredExceptions = declaredExceptions; // Determine the success and exception fields of the function. final ImmutableMap.Builder<Class<Throwable>, TFieldIdEnum> exceptionFieldsBuilder = ImmutableMap.builder(); TFieldIdEnum successField = null; if (result != null) { // if not oneway @SuppressWarnings("unchecked") final Class<? extends TBase<?, ?>> resultType = (Class<? extends TBase<?, ?>>) result.getClass(); @SuppressWarnings("unchecked") final Map<TFieldIdEnum, FieldMetaData> metaDataMap = (Map<TFieldIdEnum, FieldMetaData>) FieldMetaData.getStructMetaDataMap(resultType); for (Entry<TFieldIdEnum, FieldMetaData> e : metaDataMap.entrySet()) { final TFieldIdEnum key = e.getKey(); final String fieldName = key.getFieldName(); if ("success".equals(fieldName)) { successField = key; continue; } Class<?> fieldType = resultType.getField(fieldName).getType(); if (Throwable.class.isAssignableFrom(fieldType)) { @SuppressWarnings("unchecked") Class<Throwable> exceptionFieldType = (Class<Throwable>) fieldType; exceptionFieldsBuilder.put(exceptionFieldType, key); } } } this.successField = successField; exceptionFields = exceptionFieldsBuilder.build(); } /** * Returns {@code true} if this function is a one-way. */ public boolean isOneWay() { return result == null; } /** * Returns {@code true} if this function is asynchronous. */ public boolean isAsync() { return type == Type.ASYNC; } /** * Returns the type of this function. * * @return {@link TMessageType#CALL} or {@link TMessageType#ONEWAY} */ public byte messageType() { return isOneWay() ? TMessageType.ONEWAY : TMessageType.CALL; } /** * Returns the {@link ProcessFunction}. * * @throws ClassCastException if this function is asynchronous */ @SuppressWarnings("unchecked") public ProcessFunction<Object, TBase<?, ?>> syncFunc() { return (ProcessFunction<Object, TBase<?, ?>>) func; } /** * Returns the {@link AsyncProcessFunction}. * * @throws ClassCastException if this function is synchronous */ @SuppressWarnings("unchecked") public AsyncProcessFunction<Object, TBase<?, ?>, Object> asyncFunc() { return (AsyncProcessFunction<Object, TBase<?, ?>, Object>) func; } /** * Returns the Thrift service interface this function belongs to. */ public Class<?> serviceType() { return serviceType; } /** * Returns the name of this function. */ public String name() { return name; } /** * Returns the field that holds the successful result. */ public TFieldIdEnum successField() { return successField; } /** * Returns the field that holds the exception. */ public Collection<TFieldIdEnum> exceptionFields() { return exceptionFields.values(); } /** * Returns the exceptions declared by this function. */ public Class<?>[] declaredExceptions() { return declaredExceptions; } /** * Returns a new empty arguments instance. */ public TBase<?, ?> newArgs() { if (isAsync()) { return asyncFunc().getEmptyArgsInstance(); } else { return syncFunc().getEmptyArgsInstance(); } } /** * Returns a new arguments instance. */ public TBase<?, ?> newArgs(List<Object> args) { requireNonNull(args, "args"); final TBase<?, ?> newArgs = newArgs(); final int size = args.size(); for (int i = 0; i < size; i++) { ThriftFieldAccess.set(newArgs, argFields[i], args.get(i)); } return newArgs; } /** * Returns a new empty result instance. */ public TBase<?, ?> newResult() { return result.deepCopy(); } /** * Sets the success field of the specified {@code result} to the specified {@code value}. */ public void setSuccess(TBase<?, ?> result, Object value) { if (successField != null) { ThriftFieldAccess.set(result, successField, value); } } /** * Converts the specified {@code result} into a Java object. */ public Object getResult(TBase<?, ?> result) throws TException { for (TFieldIdEnum fieldIdEnum : exceptionFields()) { if (ThriftFieldAccess.isSet(result, fieldIdEnum)) { throw (TException) ThriftFieldAccess.get(result, fieldIdEnum); } } final TFieldIdEnum successField = successField(); if (successField == null) { //void method return null; } else if (ThriftFieldAccess.isSet(result, successField)) { return ThriftFieldAccess.get(result, successField); } else { throw new TApplicationException( TApplicationException.MISSING_RESULT, result.getClass().getName() + '.' + successField.getFieldName()); } } private static TBase<?, ?> getResult(ProcessFunction<?, ?> func) { return getResult0(Type.SYNC, func.getClass(), func.getMethodName()); } private static TBase<?, ?> getResult(AsyncProcessFunction<?, ?, ?> asyncFunc) { return getResult0(Type.ASYNC, asyncFunc.getClass(), asyncFunc.getMethodName()); } private static TBase<?, ?> getResult0(Type type, Class<?> funcClass, String methodName) { final String resultTypeName = typeName(type, funcClass, methodName, methodName + "_result"); try { @SuppressWarnings("unchecked") final Class<TBase<?, ?>> resultType = (Class<TBase<?, ?>>) Class.forName(resultTypeName, false, funcClass.getClassLoader()); return resultType.newInstance(); } catch (ClassNotFoundException ignored) { // Oneway function does not have a result type. return null; } catch (Exception e) { throw new IllegalStateException("cannot determine the result type of method: " + methodName, e); } } /** * Sets the exception field of the specified {@code result} to the specified {@code cause}. */ public boolean setException(TBase<?, ?> result, Throwable cause) { Class<?> causeType = cause.getClass(); for (Entry<Class<Throwable>, TFieldIdEnum> e : exceptionFields.entrySet()) { if (e.getKey().isAssignableFrom(causeType)) { ThriftFieldAccess.set(result, e.getValue(), cause); return true; } } return false; } private static TBase<?, ?> getArgs(ProcessFunction<?, ?> func) { return getArgs0(Type.SYNC, func.getClass(), func.getMethodName()); } private static TBase<?, ?> getArgs(AsyncProcessFunction<?, ?, ?> asyncFunc) { return getArgs0(Type.ASYNC, asyncFunc.getClass(), asyncFunc.getMethodName()); } private static TBase<?, ?> getArgs0(Type type, Class<?> funcClass, String methodName) { final String argsTypeName = typeName(type, funcClass, methodName, methodName + "_args"); try { @SuppressWarnings("unchecked") final Class<TBase<?, ?>> argsType = (Class<TBase<?, ?>>) Class.forName(argsTypeName, false, funcClass.getClassLoader()); return argsType.newInstance(); } catch (Exception e) { throw new IllegalStateException("cannot determine the args class of method: " + methodName, e); } } private static TFieldIdEnum[] getArgFields(ProcessFunction<?, ?> func) { return getArgFields0(Type.SYNC, func.getClass(), func.getMethodName()); } private static TFieldIdEnum[] getArgFields(AsyncProcessFunction<?, ?, ?> asyncFunc) { return getArgFields0(Type.ASYNC, asyncFunc.getClass(), asyncFunc.getMethodName()); } private static TFieldIdEnum[] getArgFields0(Type type, Class<?> funcClass, String methodName) { final String fieldIdEnumTypeName = typeName(type, funcClass, methodName, methodName + "_args$_Fields"); try { Class<?> fieldIdEnumType = Class.forName(fieldIdEnumTypeName, false, funcClass.getClassLoader()); return (TFieldIdEnum[]) requireNonNull(fieldIdEnumType.getEnumConstants(), "field enum may not be empty"); } catch (Exception e) { throw new IllegalStateException("cannot determine the arg fields of method: " + methodName, e); } } private static Class<?>[] getDeclaredExceptions(ProcessFunction<?, ?> func) { return getDeclaredExceptions0(Type.SYNC, func.getClass(), func.getMethodName()); } private static Class<?>[] getDeclaredExceptions(AsyncProcessFunction<?, ?, ?> asyncFunc) { return getDeclaredExceptions0(Type.ASYNC, asyncFunc.getClass(), asyncFunc.getMethodName()); } private static Class<?>[] getDeclaredExceptions0( Type type, Class<?> funcClass, String methodName) { final String ifaceTypeName = typeName(type, funcClass, methodName, "Iface"); try { Class<?> ifaceType = Class.forName(ifaceTypeName, false, funcClass.getClassLoader()); for (Method m : ifaceType.getDeclaredMethods()) { if (!m.getName().equals(methodName)) { continue; } return m.getExceptionTypes(); } throw new IllegalStateException("failed to find a method: " + methodName); } catch (Exception e) { throw new IllegalStateException( "cannot determine the declared exceptions of method: " + methodName, e); } } private static String typeName(Type type, Class<?> funcClass, String methodName, String toAppend) { final String funcClassName = funcClass.getName(); final int serviceClassEndPos = funcClassName.lastIndexOf( (type == Type.SYNC ? "$Processor$" : "$AsyncProcessor$") + methodName); if (serviceClassEndPos <= 0) { throw new IllegalStateException("cannot determine the service class of method: " + methodName); } return funcClassName.substring(0, serviceClassEndPos) + '$' + toAppend; } }