/*
* Copyright 2017 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.server.thrift;
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.TFieldRequirementType;
import org.apache.thrift.TSerializer;
import org.apache.thrift.meta_data.EnumMetaData;
import org.apache.thrift.meta_data.FieldMetaData;
import org.apache.thrift.meta_data.FieldValueMetaData;
import org.apache.thrift.meta_data.ListMetaData;
import org.apache.thrift.meta_data.MapMetaData;
import org.apache.thrift.meta_data.SetMetaData;
import org.apache.thrift.meta_data.StructMetaData;
import org.apache.thrift.protocol.TType;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.linecorp.armeria.common.thrift.ThriftProtocolFactories;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceConfig;
import com.linecorp.armeria.server.docs.DocServicePlugin;
import com.linecorp.armeria.server.docs.EndpointInfo;
import com.linecorp.armeria.server.docs.EnumInfo;
import com.linecorp.armeria.server.docs.EnumValueInfo;
import com.linecorp.armeria.server.docs.ExceptionInfo;
import com.linecorp.armeria.server.docs.FieldInfo;
import com.linecorp.armeria.server.docs.FieldRequirement;
import com.linecorp.armeria.server.docs.MethodInfo;
import com.linecorp.armeria.server.docs.NamedTypeInfo;
import com.linecorp.armeria.server.docs.ServiceInfo;
import com.linecorp.armeria.server.docs.ServiceSpecification;
import com.linecorp.armeria.server.docs.StructInfo;
import com.linecorp.armeria.server.docs.TypeSignature;
/**
* {@link DocServicePlugin} implementation that supports {@link THttpService}s.
*/
public class ThriftDocServicePlugin implements DocServicePlugin {
/**
 * Suffix of the class name of a request (arguments) struct,
 * e.g. {@code FooService$bar_args} for the method {@code bar}.
 */
private static final String REQUEST_STRUCT_SUFFIX = "_args";

// Type signatures of the Thrift base types, as returned by toTypeSignature().
private static final TypeSignature VOID = TypeSignature.ofBase("void");
private static final TypeSignature BOOL = TypeSignature.ofBase("bool");
private static final TypeSignature I8 = TypeSignature.ofBase("i8");
private static final TypeSignature I16 = TypeSignature.ofBase("i16");
private static final TypeSignature I32 = TypeSignature.ofBase("i32");
private static final TypeSignature I64 = TypeSignature.ofBase("i64");
private static final TypeSignature DOUBLE = TypeSignature.ofBase("double");
private static final TypeSignature STRING = TypeSignature.ofBase("string");
private static final TypeSignature BINARY = TypeSignature.ofBase("binary");

/**
 * Source of the documentation strings returned by {@code loadDocStrings()}.
 * Assigned exactly once, hence {@code final}.
 */
private final ThriftDocStringExtractor docstringExtractor = new ThriftDocStringExtractor();
// Methods related to generating a service specification.
/**
 * Returns the service types this plugin handles, which is {@link THttpService} only.
 */
@Override
public Set<Class<? extends Service<?, ?>>> supportedServiceTypes() {
    final Set<Class<? extends Service<?, ?>>> supported = ImmutableSet.of(THttpService.class);
    return supported;
}
/**
 * Builds a {@link ServiceSpecification} for the given {@link THttpService} configurations.
 *
 * <p>Every interface of every service entry is grouped by its enclosing (service) class.
 * An endpoint is recorded for a service class only when the configuration's path mapping
 * has an exact path.
 *
 * @param serviceConfigs the configurations whose services are {@link THttpService}s
 * @return the specification generated from the collected entries
 */
@Override
public ServiceSpecification generateSpecification(Set<ServiceConfig> serviceConfigs) {
    // Keyed by service class; a LinkedHashMap keeps the order in which classes were first seen.
    final Map<Class<?>, EntryBuilder> buildersByService = new LinkedHashMap<>();

    for (ServiceConfig config : serviceConfigs) {
        final THttpService tHttpService = config.service().as(THttpService.class).get();
        tHttpService.entries().forEach((serviceName, entry) -> {
            for (Class<?> iface : entry.interfaces()) {
                final Class<?> serviceClass = iface.getEnclosingClass();
                final EntryBuilder builder =
                        buildersByService.computeIfAbsent(serviceClass, EntryBuilder::new);

                // Add all available endpoints.
                config.pathMapping().exactPath().ifPresent(exactPath -> {
                    builder.endpoint(new EndpointInfo(
                            config.virtualHost().hostnamePattern(),
                            exactPath, serviceName,
                            tHttpService.defaultSerializationFormat(),
                            tHttpService.allowedSerializationFormats()));
                });
            }
        });
    }

    final List<Entry> entries = new ArrayList<>(buildersByService.size());
    for (EntryBuilder builder : buildersByService.values()) {
        entries.add(builder.build());
    }
    return generate(entries);
}
/**
 * Generates a {@link ServiceSpecification} from the given entries, creating one
 * {@link ServiceInfo} per entry and resolving named types via {@code newNamedTypeInfo}.
 *
 * @param entries the service types and their endpoints
 */
@VisibleForTesting
static ServiceSpecification generate(List<Entry> entries) {
    final ImmutableList.Builder<ServiceInfo> builder = ImmutableList.builder();
    for (Entry entry : entries) {
        builder.add(newServiceInfo(entry.serviceType, entry.endpointInfos));
    }
    final List<ServiceInfo> services = builder.build();
    return ServiceSpecification.generate(services, ThriftDocServicePlugin::newNamedTypeInfo);
}
/**
 * Creates a {@link ServiceInfo} for a service class by reflecting on the methods declared
 * by its nested {@code Iface} interface.
 *
 * @param serviceClass the class that encloses the {@code Iface} interface
 * @param endpoints the endpoints passed through to the {@link ServiceInfo}
 * @throws NullPointerException if {@code serviceClass} is {@code null}
 * @throws IllegalStateException if the {@code Iface} class cannot be found
 */
@VisibleForTesting
static ServiceInfo newServiceInfo(Class<?> serviceClass, Iterable<EndpointInfo> endpoints) {
    final String serviceName = requireNonNull(serviceClass, "serviceClass").getName();
    final String ifaceName = serviceName + "$Iface";

    final Method[] declared;
    try {
        // Look the interface up without initializing it, using the service's own class loader.
        declared = Class.forName(ifaceName, false, serviceClass.getClassLoader())
                        .getDeclaredMethods();
    } catch (ClassNotFoundException e) {
        throw new IllegalStateException("failed to find a class: " + ifaceName, e);
    }

    return new ServiceInfo(
            serviceName,
            Arrays.stream(declared).map(ThriftDocServicePlugin::newMethodInfo)::iterator,
            endpoints);
}
/**
 * Creates a {@link MethodInfo} for a method of a service interface by locating the
 * {@code <service>$<method>_args} and {@code <service>$<method>_result} classes.
 *
 * @param method a method whose declaring class is nested in the service class
 * @throws NullPointerException if {@code method} is {@code null}
 * @throws IllegalStateException if the {@code _args} class cannot be found
 */
private static MethodInfo newMethodInfo(Method method) {
    requireNonNull(method, "method");

    final String methodName = method.getName();
    final Class<?> serviceClass = method.getDeclaringClass().getDeclaringClass();
    final ClassLoader loader = serviceClass.getClassLoader();
    final String structPrefix = serviceClass.getName() + '$' + methodName;

    final String argsClassName = structPrefix + "_args";
    final Class<?> rawArgsClass;
    try {
        rawArgsClass = Class.forName(argsClassName, false, loader);
    } catch (ClassNotFoundException e) {
        throw new IllegalStateException("failed to find a class: " + argsClassName, e);
    }

    Class<?> rawResultClass = null;
    try {
        rawResultClass = Class.forName(structPrefix + "_result", false, loader);
    } catch (ClassNotFoundException ignored) {
        // Oneway function does not have a result type; keep null.
    }

    @SuppressWarnings("unchecked")
    final MethodInfo info =
            newMethodInfo(methodName,
                          (Class<? extends TBase<?, ?>>) rawArgsClass,
                          (Class<? extends TBase<?, ?>>) rawResultClass,
                          (Class<? extends TException>[]) method.getExceptionTypes());
    return info;
}
/**
 * Creates a {@link MethodInfo} from the argument struct, the optional result struct and the
 * declared exception types of a function.
 *
 * @param name the function name
 * @param argsClass the struct class whose fields are the function parameters
 * @param resultClass the result struct class, or {@code null} for a oneway function
 * @param exceptionClasses the declared exception types; {@link TException} itself is left out
 * @throws NullPointerException if {@code name}, {@code argsClass} or
 *                              {@code exceptionClasses} is {@code null}
 */
private static MethodInfo newMethodInfo(String name,
                                        Class<? extends TBase<?, ?>> argsClass,
                                        @Nullable Class<? extends TBase<?, ?>> resultClass,
                                        Class<? extends TException>[] exceptionClasses) {
    requireNonNull(name, "name");
    requireNonNull(argsClass, "argsClass");
    requireNonNull(exceptionClasses, "exceptionClasses");

    // Each field of the argument struct is one parameter.
    final ImmutableList.Builder<FieldInfo> parameterBuilder = ImmutableList.builder();
    for (FieldMetaData meta : FieldMetaData.getStructMetaDataMap(argsClass).values()) {
        parameterBuilder.add(newFieldInfo(argsClass, meta));
    }
    final List<FieldInfo> parameters = parameterBuilder.build();

    // The return type is the type of the 'success' field; void when there is no such field
    // or no result struct at all (oneway function).
    TypeSignature returnTypeSignature = VOID;
    if (resultClass != null) {
        for (FieldMetaData meta : FieldMetaData.getStructMetaDataMap(resultClass).values()) {
            if ("success".equals(meta.fieldName)) {
                returnTypeSignature = newFieldInfo(resultClass, meta).typeSignature();
                break;
            }
        }
    }

    final ImmutableList.Builder<TypeSignature> exceptionBuilder = ImmutableList.builder();
    for (Class<? extends TException> exceptionClass : exceptionClasses) {
        if (exceptionClass != TException.class) {
            exceptionBuilder.add(TypeSignature.ofNamed(exceptionClass));
        }
    }
    final List<TypeSignature> exceptionTypeSignatures = exceptionBuilder.build();

    return new MethodInfo(name, returnTypeSignature, parameters, exceptionTypeSignatures);
}
/**
 * Creates the {@link NamedTypeInfo} for a named type signature whose descriptor is a
 * {@link Class}: an enum, a {@link TException} subtype, or otherwise a {@link TBase} struct.
 */
private static NamedTypeInfo newNamedTypeInfo(TypeSignature typeSignature) {
    final Class<?> namedType = (Class<?>) typeSignature.namedTypeDescriptor().get();

    if (namedType.isEnum()) {
        return newEnumInfo(namedType);
    }

    if (!TException.class.isAssignableFrom(namedType)) {
        // Neither an enum nor an exception, so it is expected to be a struct.
        assert TBase.class.isAssignableFrom(namedType);
        @SuppressWarnings("unchecked")
        final Class<? extends TBase<?, ?>> structType = (Class<? extends TBase<?, ?>>) namedType;
        return newStructInfo(structType);
    }

    @SuppressWarnings("unchecked")
    final Class<? extends TException> exceptionType = (Class<? extends TException>) namedType;
    return newExceptionInfo(exceptionType);
}
/**
 * Creates an {@link EnumInfo} whose values are the names of the enum constants declared by
 * the given class, in declaration-field order.
 *
 * @param enumClass the enum class to describe
 * @return an {@link EnumInfo} named after {@code enumClass.getName()}
 * @throws NullPointerException if {@code enumClass} is {@code null}
 */
@VisibleForTesting
static EnumInfo newEnumInfo(Class<?> enumClass) {
    requireNonNull(enumClass, "enumClass");

    final List<EnumValueInfo> values = new ArrayList<>();
    for (Field field : enumClass.getDeclaredFields()) {
        if (field.isEnumConstant()) {
            // The field name of an enum constant is the constant's declared name. Using it
            // avoids reading the value reflectively, so no constant can be silently dropped
            // by an access check and an overridden toString() cannot alter the value name.
            values.add(new EnumValueInfo(field.getName()));
        }
    }

    return new EnumInfo(enumClass.getName(), values);
}
/**
 * Creates a {@link StructInfo} named after the given struct class, with one
 * {@link FieldInfo} per entry of the struct's Thrift field metadata map.
 *
 * @param structClass the struct class to describe
 */
@VisibleForTesting
static StructInfo newStructInfo(Class<? extends TBase<?, ?>> structClass) {
    final String name = structClass.getName();

    final List<FieldInfo> fields = new ArrayList<>();
    for (FieldMetaData meta : FieldMetaData.getStructMetaDataMap(structClass).values()) {
        fields.add(newFieldInfo(structClass, meta));
    }

    return new StructInfo(name, fields);
}
/**
 * Creates an {@link ExceptionInfo} for the given exception class, reading its fields from
 * the static {@code metaDataMap} field declared by the class. An exception class that
 * declares no such field is described with an empty field list.
 *
 * @param exceptionClass the exception class to describe
 * @throws NullPointerException if {@code exceptionClass} is {@code null}
 */
@VisibleForTesting
static ExceptionInfo newExceptionInfo(Class<? extends TException> exceptionClass) {
    requireNonNull(exceptionClass, "exceptionClass");
    final String name = exceptionClass.getName();

    final Map<?, FieldMetaData> metaDataMap;
    try {
        @SuppressWarnings("unchecked")
        final Map<?, FieldMetaData> declared =
                (Map<?, FieldMetaData>) exceptionClass.getDeclaredField("metaDataMap").get(null);
        metaDataMap = declared;
    } catch (IllegalAccessException e) {
        throw new AssertionError("will not happen", e);
    } catch (NoSuchFieldException ignored) {
        // No metadata declared: report the exception without fields.
        final List<FieldInfo> noFields = Collections.emptyList();
        return new ExceptionInfo(name, noFields);
    }

    final ImmutableList.Builder<FieldInfo> builder = ImmutableList.builder();
    for (FieldMetaData meta : metaDataMap.values()) {
        builder.add(newFieldInfo(exceptionClass, meta));
    }
    final List<FieldInfo> fields = builder.build();
    return new ExceptionInfo(name, fields);
}
/**
 * Creates a {@link FieldInfo} from Thrift field metadata.
 *
 * @param parentType the type that declares the field; used to detect a self-referencing field
 * @param fieldMetaData the metadata of the field
 * @throws NullPointerException if {@code fieldMetaData} is {@code null}
 */
@VisibleForTesting
static FieldInfo newFieldInfo(Class<?> parentType, FieldMetaData fieldMetaData) {
    requireNonNull(fieldMetaData, "fieldMetaData");
    final FieldValueMetaData valueMeta = fieldMetaData.valueMetaData;

    // Handle the special case where a struct field refers to itself,
    // where the Thrift compiler handles it as a typedef.
    final boolean refersToParent =
            valueMeta.isStruct() && valueMeta.isTypedef() &&
            parentType.getSimpleName().equals(valueMeta.getTypedefName());

    final TypeSignature typeSignature;
    if (refersToParent) {
        typeSignature = TypeSignature.ofNamed(parentType);
    } else {
        typeSignature = toTypeSignature(valueMeta);
    }

    final FieldRequirement requirement = convertRequirement(fieldMetaData.requirementType);
    return new FieldInfo(fieldMetaData.fieldName, requirement, typeSignature);
}
/**
 * Converts Thrift field value metadata into a {@link TypeSignature}.
 *
 * <p>Structs and enums become named types; lists, sets and maps are converted recursively;
 * binary and the base types map to the corresponding base signatures. Anything else becomes
 * an unresolved signature named after the typedef, or {@code "unknown"} when there is none.
 */
@VisibleForTesting
static TypeSignature toTypeSignature(FieldValueMetaData fieldValueMetaData) {
    if (fieldValueMetaData instanceof StructMetaData) {
        final StructMetaData structMeta = (StructMetaData) fieldValueMetaData;
        return TypeSignature.ofNamed(structMeta.structClass);
    }
    if (fieldValueMetaData instanceof EnumMetaData) {
        final EnumMetaData enumMeta = (EnumMetaData) fieldValueMetaData;
        return TypeSignature.ofNamed(enumMeta.enumClass);
    }
    if (fieldValueMetaData instanceof ListMetaData) {
        final ListMetaData listMeta = (ListMetaData) fieldValueMetaData;
        return TypeSignature.ofList(toTypeSignature(listMeta.elemMetaData));
    }
    if (fieldValueMetaData instanceof SetMetaData) {
        final SetMetaData setMeta = (SetMetaData) fieldValueMetaData;
        return TypeSignature.ofSet(toTypeSignature(setMeta.elemMetaData));
    }
    if (fieldValueMetaData instanceof MapMetaData) {
        final MapMetaData mapMeta = (MapMetaData) fieldValueMetaData;
        final TypeSignature keySignature = toTypeSignature(mapMeta.keyMetaData);
        final TypeSignature valueSignature = toTypeSignature(mapMeta.valueMetaData);
        return TypeSignature.ofMap(keySignature, valueSignature);
    }

    // Checked before the switch so that binary is not reported as a plain string.
    if (fieldValueMetaData.isBinary()) {
        return BINARY;
    }

    switch (fieldValueMetaData.type) {
        case TType.VOID:
            return VOID;
        case TType.BOOL:
            return BOOL;
        case TType.BYTE:
            return I8;
        case TType.I16:
            return I16;
        case TType.I32:
            return I32;
        case TType.I64:
            return I64;
        case TType.DOUBLE:
            return DOUBLE;
        case TType.STRING:
            return STRING;
        default:
            // Not a base type; fall through to the unresolved signature below.
            break;
    }

    final String unresolvedName =
            fieldValueMetaData.isTypedef() ? fieldValueMetaData.getTypedefName() : null;
    return TypeSignature.ofUnresolved(firstNonNull(unresolvedName, "unknown"));
}
/**
 * Maps a {@link TFieldRequirementType} constant to the matching {@link FieldRequirement}.
 *
 * @throws IllegalArgumentException if {@code value} is not one of the known constants
 */
private static FieldRequirement convertRequirement(byte value) {
    if (value == TFieldRequirementType.REQUIRED) {
        return FieldRequirement.REQUIRED;
    }
    if (value == TFieldRequirementType.OPTIONAL) {
        return FieldRequirement.OPTIONAL;
    }
    if (value == TFieldRequirementType.DEFAULT) {
        return FieldRequirement.DEFAULT;
    }
    throw new IllegalArgumentException("unknown requirement type: " + value);
}
/**
 * A service type together with an immutable copy of its endpoints.
 */
@VisibleForTesting
static final class Entry {
    /** The service class this entry describes. */
    final Class<?> serviceType;
    /** The endpoints of the service; an immutable copy of the list given at construction. */
    final List<EndpointInfo> endpointInfos;

    Entry(Class<?> serviceType, List<EndpointInfo> endpointInfos) {
        this.endpointInfos = ImmutableList.copyOf(endpointInfos);
        this.serviceType = serviceType;
    }
}
/**
 * Accumulates the endpoints of one service type and builds an {@link Entry} from them.
 */
@VisibleForTesting
static final class EntryBuilder {
    private final Class<?> serviceType;
    private final List<EndpointInfo> endpointInfos = new ArrayList<>();

    /**
     * @throws NullPointerException if {@code serviceType} is {@code null}
     */
    EntryBuilder(Class<?> serviceType) {
        requireNonNull(serviceType, "serviceType");
        this.serviceType = serviceType;
    }

    /**
     * Adds an endpoint and returns this builder.
     *
     * @throws NullPointerException if {@code endpointInfo} is {@code null}
     */
    EntryBuilder endpoint(EndpointInfo endpointInfo) {
        requireNonNull(endpointInfo, "endpointInfo");
        endpointInfos.add(endpointInfo);
        return this;
    }

    /** Creates an {@link Entry} holding the endpoints added so far. */
    Entry build() {
        return new Entry(serviceType, endpointInfos);
    }
}
// Methods related to extracting documentation strings.
/**
 * Loads documentation strings through the class loader of every interface of every
 * {@link THttpService} entry. When the same key is produced more than once, the first
 * value encountered is kept.
 *
 * @param serviceConfigs the configurations whose services are {@link THttpService}s
 * @return an immutable map of the collected documentation strings
 */
@Override
public Map<String, String> loadDocStrings(Set<ServiceConfig> serviceConfigs) {
    return serviceConfigs.stream()
            .map(config -> config.service().as(THttpService.class).get())
            .flatMap(service -> service.entries().values().stream())
            .flatMap(entry -> entry.interfaces().stream())
            .map(Class::getClassLoader)
            .flatMap(loader -> docstringExtractor.getAllDocStrings(loader).entrySet().stream())
            .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue,
                                    (first, second) -> first));
}
// Methods related to serializing example requests.
/**
 * Returns the example request types this plugin handles, which is {@link TBase} only.
 */
@Override
public Set<Class<?>> supportedExampleRequestTypes() {
    final Set<Class<?>> supported = ImmutableSet.of(TBase.class);
    return supported;
}
/**
 * Guesses the service name of an example request: the name of the class enclosing the
 * request struct, or empty when {@code asTBase} does not accept the request.
 */
@Override
public Optional<String> guessServiceName(Object exampleRequest) {
    final TBase<?, ?> request = asTBase(exampleRequest);
    if (request != null) {
        final Class<?> serviceClass = request.getClass().getEnclosingClass();
        return Optional.of(serviceClass.getName());
    }
    return Optional.empty();
}
/**
 * Guesses the method name of an example request: the part of the request class name
 * between the last {@code '$'} and the {@code "_args"} suffix, or empty when
 * {@code asTBase} does not accept the request.
 */
@Override
public Optional<String> guessServiceMethodName(Object exampleRequest) {
    final TBase<?, ?> request = asTBase(exampleRequest);
    if (request != null) {
        final String className = request.getClass().getName();
        final int start = className.lastIndexOf('$') + 1;
        final int end = className.length() - REQUEST_STRUCT_SUFFIX.length();
        return Optional.of(className.substring(start, end));
    }
    return Optional.empty();
}
/**
 * Serializes an example request with the {@code TEXT} protocol factory.
 *
 * @param serviceName not used by this implementation
 * @param methodName not used by this implementation
 * @param exampleRequest the request to serialize
 * @return the serialized request, or empty if {@code exampleRequest} is not a {@link TBase}
 */
@Override
public Optional<String> serializeExampleRequest(String serviceName, String methodName,
                                                Object exampleRequest) {
    if (exampleRequest instanceof TBase) {
        final TBase<?, ?> request = (TBase<?, ?>) exampleRequest;
        final TSerializer textSerializer = new TSerializer(ThriftProtocolFactories.TEXT);
        try {
            final String serialized = textSerializer.toString(request, StandardCharsets.UTF_8.name());
            return Optional.of(serialized);
        } catch (TException e) {
            throw new Error("should never reach here", e);
        }
    }
    return Optional.empty();
}
/**
 * Returns the given example request as a {@link TBase} if it looks like a request
 * (arguments) struct of a service, or {@code null} otherwise.
 *
 * <p>A request is accepted only when it is a {@link TBase}, its class name ends with
 * {@code "_args"}, and its class is nested directly inside a top-level class.
 *
 * @param exampleRequest the candidate request; may be {@code null} or of any type
 * @return the request cast to {@link TBase}, or {@code null} if it is not accepted
 */
@Nullable
private static TBase<?, ?> asTBase(Object exampleRequest) {
    // Guard the cast: a null or non-TBase request is simply "not a request struct",
    // consistent with how serializeExampleRequest() treats such input.
    if (!(exampleRequest instanceof TBase)) {
        return null;
    }

    final TBase<?, ?> exampleTBase = (TBase<?, ?>) exampleRequest;
    final Class<?> type = exampleTBase.getClass();
    if (!type.getName().endsWith(REQUEST_STRUCT_SUFFIX)) {
        return null;
    }

    // The struct must be nested in a service class which itself is a top-level class.
    final Class<?> serviceType = type.getEnclosingClass();
    if (serviceType == null || serviceType.getEnclosingClass() != null) {
        return null;
    }

    return exampleTBase;
}
}