/*
* Copyright 2017 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.server.grpc;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.EnumDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Descriptors.ServiceDescriptor;
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceConfig;
import com.linecorp.armeria.server.docs.DocServicePlugin;
import com.linecorp.armeria.server.docs.EndpointInfo;
import com.linecorp.armeria.server.docs.EnumInfo;
import com.linecorp.armeria.server.docs.EnumValueInfo;
import com.linecorp.armeria.server.docs.FieldInfo;
import com.linecorp.armeria.server.docs.FieldRequirement;
import com.linecorp.armeria.server.docs.MethodInfo;
import com.linecorp.armeria.server.docs.NamedTypeInfo;
import com.linecorp.armeria.server.docs.ServiceInfo;
import com.linecorp.armeria.server.docs.ServiceSpecification;
import com.linecorp.armeria.server.docs.StructInfo;
import com.linecorp.armeria.server.docs.TypeSignature;
import io.grpc.ServerServiceDefinition;
import io.grpc.protobuf.ProtoFileDescriptorSupplier;
/**
* {@link DocServicePlugin} implementation that supports {@link GrpcService}s.
*/
public class GrpcDocServicePlugin implements DocServicePlugin {
@VisibleForTesting
static final TypeSignature BOOL = TypeSignature.ofBase("bool");
@VisibleForTesting
static final TypeSignature INT32 = TypeSignature.ofBase("int32");
@VisibleForTesting
static final TypeSignature INT64 = TypeSignature.ofBase("int64");
@VisibleForTesting
static final TypeSignature UINT32 = TypeSignature.ofBase("uint32");
@VisibleForTesting
static final TypeSignature UINT64 = TypeSignature.ofBase("uint64");
@VisibleForTesting
static final TypeSignature SINT32 = TypeSignature.ofBase("sint32");
@VisibleForTesting
static final TypeSignature SINT64 = TypeSignature.ofBase("sint64");
@VisibleForTesting
static final TypeSignature FLOAT = TypeSignature.ofBase("float");
@VisibleForTesting
static final TypeSignature DOUBLE = TypeSignature.ofBase("double");
@VisibleForTesting
static final TypeSignature FIXED32 = TypeSignature.ofBase("fixed32");
@VisibleForTesting
static final TypeSignature FIXED64 = TypeSignature.ofBase("fixed64");
@VisibleForTesting
static final TypeSignature SFIXED32 = TypeSignature.ofBase("sfixed32");
@VisibleForTesting
static final TypeSignature SFIXED64 = TypeSignature.ofBase("sfixed64");
@VisibleForTesting
static final TypeSignature STRING = TypeSignature.ofBase("string");
@VisibleForTesting
static final TypeSignature BYTES = TypeSignature.ofBase("bytes");
@VisibleForTesting
static final TypeSignature UNKNOWN = TypeSignature.ofBase("unknown");
private final GrpcDocStringExtractor docstringExtractor = new GrpcDocStringExtractor();
@Override
public Set<Class<? extends Service<?, ?>>> supportedServiceTypes() {
return ImmutableSet.of(GrpcService.class);
}
@Override
public ServiceSpecification generateSpecification(Set<ServiceConfig> serviceConfigs) {
final Map<String, ServiceEntryBuilder> map = new LinkedHashMap<>();
for (ServiceConfig serviceConfig : serviceConfigs) {
final GrpcService grpcService = serviceConfig.service().as(GrpcService.class).get();
for (ServerServiceDefinition service : grpcService.services()) {
map.computeIfAbsent(
service.getServiceDescriptor().getName(),
s -> {
FileDescriptor fileDescriptor =
((ProtoFileDescriptorSupplier)
service.getServiceDescriptor().getSchemaDescriptor())
.getFileDescriptor();
ServiceDescriptor serviceDescriptor = fileDescriptor.getServices().stream()
.filter(sd -> sd.getFullName().equals(
service.getServiceDescriptor().getName()))
.findFirst()
.orElseThrow(IllegalStateException::new);
return new ServiceEntryBuilder(serviceDescriptor);
});
}
serviceConfig.pathMapping().prefix().ifPresent(
path -> {
for (ServerServiceDefinition service : grpcService.services()) {
final String serviceName = service.getServiceDescriptor().getName();
map.get(serviceName).endpoint(
new EndpointInfo(
serviceConfig.virtualHost().hostnamePattern(),
// TODO(anuraag): Move EndpointInfo from service to function,
// which is where we display it.
path + serviceName + "/*",
"",
GrpcSerializationFormats.PROTO,
ImmutableList.of(GrpcSerializationFormats.PROTO)));
}
});
}
return generate(map.values().stream()
.map(ServiceEntryBuilder::build)
.collect(toImmutableList()));
}
@Override
public Map<String, String> loadDocStrings(Set<ServiceConfig> serviceConfigs) {
return serviceConfigs.stream()
.flatMap(c -> c.service().as(GrpcService.class).get().services().stream())
.flatMap(s -> docstringExtractor.getAllDocStrings(s.getClass().getClassLoader())
.entrySet().stream())
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue, (a, b) -> a));
}
@VisibleForTesting
ServiceSpecification generate(List<ServiceEntry> entries) {
final List<ServiceInfo> services = entries.stream()
.map(this::newServiceInfo)
.collect(toImmutableList());
return ServiceSpecification.generate(services, this::newNamedTypeInfo);
}
private NamedTypeInfo newNamedTypeInfo(TypeSignature typeSignature) {
Object descriptor = typeSignature.namedTypeDescriptor().get();
if (descriptor instanceof Descriptor) {
return newStructInfo((Descriptor) descriptor);
}
assert descriptor instanceof EnumDescriptor;
return newEnumInfo((EnumDescriptor) descriptor);
}
ServiceInfo newServiceInfo(ServiceEntry entry) {
final List<MethodInfo> functions = entry.methods().stream()
.map(this::newMethodInfo)
.collect(toImmutableList());
return new ServiceInfo(
entry.name(),
functions,
entry.endpointInfos);
}
@VisibleForTesting
MethodInfo newMethodInfo(MethodDescriptor method) {
return new MethodInfo(
method.getName(),
namedMessageSignature(method.getOutputType()),
// GRPC methods always take a single request parameter of message type.
ImmutableList.of(
new FieldInfo(
"request",
FieldRequirement.REQUIRED,
namedMessageSignature(method.getInputType()))),
ImmutableList.of());
}
@VisibleForTesting
StructInfo newStructInfo(Descriptor descriptor) {
return new StructInfo(
descriptor.getFullName(),
descriptor.getFields().stream()
.map(this::newFieldInfo)
.collect(toImmutableList()));
}
private FieldInfo newFieldInfo(FieldDescriptor fieldDescriptor) {
return new FieldInfo(
fieldDescriptor.getName(),
fieldDescriptor.isRequired() ? FieldRequirement.REQUIRED : FieldRequirement.OPTIONAL,
newFieldTypeInfo(fieldDescriptor));
}
@VisibleForTesting
TypeSignature newFieldTypeInfo(FieldDescriptor fieldDescriptor) {
if (fieldDescriptor.isMapField()) {
return TypeSignature.ofMap(
newFieldTypeInfo(fieldDescriptor.getMessageType().findFieldByNumber(1)),
newFieldTypeInfo(fieldDescriptor.getMessageType().findFieldByNumber(2)));
}
final TypeSignature fieldType;
switch (fieldDescriptor.getType()) {
case BOOL:
fieldType = BOOL;
break;
case BYTES:
fieldType = BYTES;
break;
case DOUBLE:
fieldType = DOUBLE;
break;
case FIXED32:
fieldType = FIXED32;
break;
case FIXED64:
fieldType = FIXED64;
break;
case FLOAT:
fieldType = FLOAT;
break;
case INT32:
fieldType = INT32;
break;
case INT64:
fieldType = INT64;
break;
case SFIXED32:
fieldType = SFIXED32;
break;
case SFIXED64:
fieldType = SFIXED64;
break;
case SINT32:
fieldType = SINT32;
break;
case SINT64:
fieldType = SINT64;
break;
case STRING:
fieldType = STRING;
break;
case UINT32:
fieldType = UINT32;
break;
case UINT64:
fieldType = UINT64;
break;
case MESSAGE:
fieldType = namedMessageSignature(fieldDescriptor.getMessageType());
break;
case GROUP:
// This type has been deprecated since the launch of protocol buffers to open source.
// There is no real metadata for this in the descriptor so we just treat as UNKNOWN
// since it shouldn't happen in practice anyways.
fieldType = UNKNOWN;
break;
case ENUM:
fieldType = TypeSignature.ofNamed(
fieldDescriptor.getEnumType().getFullName(), fieldDescriptor.getEnumType());
break;
default:
fieldType = UNKNOWN;
break;
}
return fieldDescriptor.isRepeated() ? TypeSignature.ofContainer("repeated", fieldType) : fieldType;
}
@VisibleForTesting
EnumInfo newEnumInfo(EnumDescriptor enumDescriptor) {
return new EnumInfo(
enumDescriptor.getFullName(),
enumDescriptor.getValues().stream()
.map(d -> new EnumValueInfo(d.getName()))
.collect(toImmutableList()));
}
private TypeSignature namedMessageSignature(Descriptor descriptor) {
return TypeSignature.ofNamed(descriptor.getFullName(), descriptor);
}
@VisibleForTesting
static final class ServiceEntry {
final ServiceDescriptor service;
final List<EndpointInfo> endpointInfos;
ServiceEntry(ServiceDescriptor service,
List<EndpointInfo> endpointInfos) {
this.service = service;
this.endpointInfos = ImmutableList.copyOf(endpointInfos);
}
String name() {
return service.getFullName();
}
List<MethodDescriptor> methods() {
return ImmutableList.copyOf(service.getMethods());
}
}
@VisibleForTesting
static final class ServiceEntryBuilder {
private final ServiceDescriptor service;
private final List<EndpointInfo> endpointInfos = new ArrayList<>();
ServiceEntryBuilder(ServiceDescriptor service) {
this.service = service;
}
ServiceEntryBuilder endpoint(EndpointInfo endpointInfo) {
endpointInfos.add(requireNonNull(endpointInfo, "endpointInfo"));
return this;
}
ServiceEntry build() {
return new ServiceEntry(service, endpointInfos);
}
}
}