/*
* Copyright 2017 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.server.grpc;
import static com.linecorp.armeria.common.http.HttpSessionProtocols.HTTP;
import static com.linecorp.armeria.internal.grpc.GrpcTestUtil.REQUEST_MESSAGE;
import static com.linecorp.armeria.internal.grpc.GrpcTestUtil.RESPONSE_MESSAGE;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.catchThrowable;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.DisableOnDebug;
import org.junit.rules.TestRule;
import org.junit.rules.Timeout;
import com.google.common.base.Strings;
import com.google.common.primitives.Bytes;
import com.google.common.primitives.Ints;
import com.google.protobuf.ByteString;
import com.google.protobuf.util.JsonFormat;
import com.linecorp.armeria.client.http.HttpClient;
import com.linecorp.armeria.client.http.HttpClientFactory;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.common.http.AggregatedHttpMessage;
import com.linecorp.armeria.common.http.HttpHeaderNames;
import com.linecorp.armeria.common.http.HttpHeaders;
import com.linecorp.armeria.common.http.HttpMethod;
import com.linecorp.armeria.common.http.HttpStatus;
import com.linecorp.armeria.grpc.testing.Messages.EchoStatus;
import com.linecorp.armeria.grpc.testing.Messages.Payload;
import com.linecorp.armeria.grpc.testing.Messages.SimpleRequest;
import com.linecorp.armeria.grpc.testing.Messages.SimpleResponse;
import com.linecorp.armeria.grpc.testing.Messages.StreamingOutputCallRequest;
import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc;
import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceBlockingStub;
import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceImplBase;
import com.linecorp.armeria.grpc.testing.UnitTestServiceGrpc.UnitTestServiceStub;
import com.linecorp.armeria.internal.grpc.GrpcHeaderNames;
import com.linecorp.armeria.internal.grpc.GrpcTestUtil;
import com.linecorp.armeria.internal.grpc.StreamRecorder;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.testing.server.ServerRule;
import io.grpc.Codec;
import io.grpc.DecompressorRegistry;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.AsciiString;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
public class GrpcServiceServerTest {
private static final int MAX_MESSAGE_SIZE = 16 * 1024 * 1024;
private static final AsciiString LARGE_PAYLOAD = AsciiString.of(Strings.repeat("a", MAX_MESSAGE_SIZE + 1));
private static class UnitTestServiceImpl extends UnitTestServiceImplBase {
private static final AttributeKey<Integer> CHECK_REQUEST_CONTEXT_COUNT =
AttributeKey.valueOf(UnitTestServiceImpl.class, "CHECK_REQUEST_CONTEXT_COUNT");
@Override
public void staticUnaryCall(SimpleRequest request, StreamObserver<SimpleResponse> responseObserver) {
if (!request.equals(REQUEST_MESSAGE)) {
responseObserver.onError(new IllegalArgumentException("Unexpected request: " + request));
return;
}
responseObserver.onNext(RESPONSE_MESSAGE);
responseObserver.onCompleted();
}
@Override
public void staticStreamedOutputCall(SimpleRequest request,
StreamObserver<SimpleResponse> responseObserver) {
if (!request.equals(REQUEST_MESSAGE)) {
responseObserver.onError(new IllegalArgumentException("Unexpected request: " + request));
return;
}
responseObserver.onNext(RESPONSE_MESSAGE);
responseObserver.onNext(RESPONSE_MESSAGE);
responseObserver.onCompleted();
}
@Override
public void errorNoMessage(SimpleRequest request, StreamObserver<SimpleResponse> responseObserver) {
responseObserver.onError(Status.ABORTED.asException());
}
@Override
public void errorWithMessage(SimpleRequest request, StreamObserver<SimpleResponse> responseObserver) {
responseObserver.onError(Status.ABORTED.withDescription("aborted call").asException());
}
@Override
public void unaryThrowsError(SimpleRequest request, StreamObserver<SimpleResponse> responseObserver) {
throw Status.ABORTED.withDescription("call aborted").asRuntimeException();
}
@Override
public StreamObserver<SimpleRequest> streamThrowsError(
StreamObserver<SimpleResponse> responseObserver) {
return new StreamObserver<SimpleRequest>() {
@Override
public void onNext(SimpleRequest value) {
throw Status.ABORTED.withDescription("bad streaming message").asRuntimeException();
}
@Override
public void onError(Throwable t) {}
@Override
public void onCompleted() {}
};
}
@Override
public StreamObserver<SimpleRequest> streamThrowsErrorInStub(
StreamObserver<SimpleResponse> responseObserver) {
throw Status.ABORTED.withDescription("bad streaming stub").asRuntimeException();
}
@Override
public void staticUnaryCallSetsMessageCompression(SimpleRequest request,
StreamObserver<SimpleResponse> responseObserver) {
if (!request.equals(REQUEST_MESSAGE)) {
responseObserver.onError(new IllegalArgumentException("Unexpected request: " + request));
return;
}
ServerCallStreamObserver<SimpleResponse> callObserver =
(ServerCallStreamObserver<SimpleResponse>) responseObserver;
callObserver.setCompression("gzip");
callObserver.setMessageCompression(true);
responseObserver.onNext(RESPONSE_MESSAGE);
responseObserver.onCompleted();
}
@Override
public StreamObserver<SimpleRequest> checkRequestContext(
StreamObserver<SimpleResponse> responseObserver) {
RequestContext ctx = RequestContext.current();
ctx.attr(CHECK_REQUEST_CONTEXT_COUNT).set(0);
return new StreamObserver<SimpleRequest>() {
@Override
public void onNext(SimpleRequest value) {
RequestContext ctx = RequestContext.current();
Attribute<Integer> attr = ctx.attr(CHECK_REQUEST_CONTEXT_COUNT);
attr.set(attr.get() + 1);
}
@Override
public void onError(Throwable t) {}
@Override
public void onCompleted() {
RequestContext ctx = RequestContext.current();
int count = ctx.attr(CHECK_REQUEST_CONTEXT_COUNT).get();
responseObserver.onNext(
SimpleResponse.newBuilder()
.setPayload(
Payload.newBuilder()
.setBody(
ByteString.copyFromUtf8(
Integer.toString(count))))
.build());
responseObserver.onCompleted();
}
};
}
}
@ClassRule
public static ServerRule server = new ServerRule() {
@Override
protected void configure(ServerBuilder sb) throws Exception {
sb.numWorkers(1);
sb.port(0, HTTP);
sb.defaultMaxRequestLength(0);
sb.serviceUnder("/", new GrpcServiceBuilder()
.setMaxInboundMessageSizeBytes(MAX_MESSAGE_SIZE)
.addService(new UnitTestServiceImpl())
.enableUnframedRequests(true)
.supportedSerializationFormats(GrpcSerializationFormats.values())
.build());
}
};
@Rule
public TestRule globalTimeout = new DisableOnDebug(new Timeout(10, TimeUnit.SECONDS));
private static ManagedChannel channel;
private UnitTestServiceBlockingStub blockingClient;
private UnitTestServiceStub streamingClient;
@BeforeClass
public static void setUpChannel() {
channel = ManagedChannelBuilder.forAddress("127.0.0.1", server.httpPort())
.usePlaintext(true)
.build();
}
@AfterClass
public static void tearDownChannel() {
channel.shutdownNow();
}
@Before
public void setUp() {
blockingClient = UnitTestServiceGrpc.newBlockingStub(channel);
streamingClient = UnitTestServiceGrpc.newStub(channel);
}
@Test
public void unary_normal() throws Exception {
assertThat(blockingClient.staticUnaryCall(REQUEST_MESSAGE)).isEqualTo(RESPONSE_MESSAGE);
}
@Test
public void streamedOutput_normal() throws Exception {
StreamRecorder<SimpleResponse> recorder = StreamRecorder.create();
streamingClient.staticStreamedOutputCall(REQUEST_MESSAGE, recorder);
recorder.awaitCompletion();
assertThat(recorder.getValues()).containsExactly(RESPONSE_MESSAGE, RESPONSE_MESSAGE);
}
@Test
public void error_noMessage() throws Exception {
StatusRuntimeException t = (StatusRuntimeException) catchThrowable(
() -> blockingClient.errorNoMessage(GrpcTestUtil.REQUEST_MESSAGE));
assertThat(t.getStatus().getCode()).isEqualTo(Code.ABORTED);
assertThat(t.getStatus().getDescription()).isNull();
}
@Test
public void error_withMessage() throws Exception {
StatusRuntimeException t = (StatusRuntimeException) catchThrowable(
() -> blockingClient.errorWithMessage(GrpcTestUtil.REQUEST_MESSAGE));
assertThat(t.getStatus().getCode()).isEqualTo(Code.ABORTED);
assertThat(t.getStatus().getDescription()).isEqualTo("aborted call");
}
@Test
public void error_thrown_unary() throws Exception {
StatusRuntimeException t = (StatusRuntimeException) catchThrowable(
() -> blockingClient.unaryThrowsError(GrpcTestUtil.REQUEST_MESSAGE));
assertThat(t.getStatus().getCode()).isEqualTo(Code.ABORTED);
assertThat(t.getStatus().getDescription()).isEqualTo("call aborted");
}
@Test
public void error_thrown_streamMessage() throws Exception {
StreamRecorder<SimpleResponse> response = StreamRecorder.create();
StreamObserver<SimpleRequest> request = streamingClient.streamThrowsError(response);
request.onNext(GrpcTestUtil.REQUEST_MESSAGE);
response.awaitCompletion();
StatusRuntimeException t = (StatusRuntimeException) response.getError();
assertThat(t.getStatus().getCode()).isEqualTo(Code.ABORTED);
assertThat(t.getStatus().getDescription()).isEqualTo("bad streaming message");
}
@Test
public void error_thrown_streamStub() throws Exception {
StreamRecorder<SimpleResponse> response = StreamRecorder.create();
streamingClient.streamThrowsErrorInStub(response);
response.awaitCompletion();
StatusRuntimeException t = (StatusRuntimeException) response.getError();
assertThat(t.getStatus().getCode()).isEqualTo(Code.ABORTED);
assertThat(t.getStatus().getDescription()).isEqualTo("bad streaming stub");
}
@Test
public void requestContextSet() throws Exception {
StreamRecorder<SimpleResponse> response = StreamRecorder.create();
StreamObserver<SimpleRequest> request = streamingClient.checkRequestContext(response);
request.onNext(GrpcTestUtil.REQUEST_MESSAGE);
request.onNext(GrpcTestUtil.REQUEST_MESSAGE);
request.onNext(GrpcTestUtil.REQUEST_MESSAGE);
request.onCompleted();
response.awaitCompletion();
assertThat(response.getValues())
.containsExactly(
SimpleResponse.newBuilder()
.setPayload(Payload.newBuilder()
.setBody(ByteString.copyFromUtf8("3")))
.build());
}
@Test
public void tooLargeRequest_uncompressed() throws Exception {
SimpleRequest request = SimpleRequest.newBuilder()
.setPayload(
Payload.newBuilder()
.setBody(ByteString.copyFrom(
LARGE_PAYLOAD.toByteArray())))
.build();
StatusRuntimeException t =
(StatusRuntimeException) catchThrowable(
() -> blockingClient.staticUnaryCall(request));
// NB: Since GRPC does not support HTTP/1, it just resets the stream with an HTTP/2 CANCEL error code,
// which clients would interpret as Code.CANCELLED. Armeria supports HTTP/1, so more generically returns
// an HTTP 500.
assertThat(t.getStatus().getCode()).isEqualTo(Code.UNKNOWN);
}
@Test
public void tooLargeRequest_compressed() throws Exception {
SimpleRequest request = SimpleRequest.newBuilder()
.setPayload(
Payload.newBuilder()
.setBody(ByteString.copyFrom(
LARGE_PAYLOAD.toByteArray())))
.build();
StatusRuntimeException t =
(StatusRuntimeException) catchThrowable(
() -> blockingClient.withCompression("gzip").staticUnaryCall(request));
// NB: Since GRPC does not support HTTP/1, it just resets the stream with an HTTP/2 CANCEL error code,
// which clients would interpret as Code.CANCELLED. Armeria supports HTTP/1, so more generically returns
// an HTTP 500.
assertThat(t.getStatus().getCode()).isEqualTo(Code.UNKNOWN);
}
@Test
public void uncompressedClient_compressedEndpoint() throws Exception {
ManagedChannel nonDecompressingChannel =
ManagedChannelBuilder.forAddress("127.0.0.1", server.httpPort())
.decompressorRegistry(
DecompressorRegistry.emptyInstance()
.with(Codec.Identity.NONE, false))
.usePlaintext(true)
.build();
UnitTestServiceBlockingStub client = UnitTestServiceGrpc.newBlockingStub(nonDecompressingChannel);
assertThat(client.staticUnaryCallSetsMessageCompression(REQUEST_MESSAGE))
.isEqualTo(RESPONSE_MESSAGE);
nonDecompressingChannel.shutdownNow();
}
@Test
public void compressedClient_compressedEndpoint() throws Exception {
assertThat(blockingClient.staticUnaryCallSetsMessageCompression(REQUEST_MESSAGE))
.isEqualTo(RESPONSE_MESSAGE);
}
@Test
public void unframed() throws Exception {
HttpClient client = HttpClientFactory.DEFAULT
.newClient("none+" + server.httpUri("/"),
HttpClient.class);
AggregatedHttpMessage response = client.execute(
HttpHeaders.of(HttpMethod.POST,
UnitTestServiceGrpc.METHOD_STATIC_UNARY_CALL.getFullMethodName())
.set(HttpHeaderNames.CONTENT_TYPE, "application/protobuf"),
REQUEST_MESSAGE.toByteArray()).aggregate().get();
SimpleResponse message = SimpleResponse.parseFrom(response.content().array());
assertThat(message).isEqualTo(RESPONSE_MESSAGE);
}
@Test
public void unframed_acceptEncoding() throws Exception {
HttpClient client = HttpClientFactory.DEFAULT
.newClient("none+" + server.httpUri("/"),
HttpClient.class);
AggregatedHttpMessage response = client.execute(
HttpHeaders.of(HttpMethod.POST,
UnitTestServiceGrpc.METHOD_STATIC_UNARY_CALL.getFullMethodName())
.set(HttpHeaderNames.CONTENT_TYPE, "application/protobuf")
.set(GrpcHeaderNames.GRPC_ACCEPT_ENCODING, "gzip,none"),
REQUEST_MESSAGE.toByteArray()).aggregate().get();
SimpleResponse message = SimpleResponse.parseFrom(response.content().array());
assertThat(message).isEqualTo(RESPONSE_MESSAGE);
}
@Test
public void unframed_streamingApi() throws Exception {
HttpClient client = HttpClientFactory.DEFAULT
.newClient("none+" + server.httpUri("/"),
HttpClient.class);
AggregatedHttpMessage response = client.execute(
HttpHeaders.of(HttpMethod.POST,
UnitTestServiceGrpc.METHOD_STATIC_STREAMED_OUTPUT_CALL.getFullMethodName())
.set(HttpHeaderNames.CONTENT_TYPE, "application/protobuf"),
StreamingOutputCallRequest.getDefaultInstance().toByteArray()).aggregate().get();
assertThat(response.status()).isEqualTo(HttpStatus.BAD_REQUEST);
}
@Test
public void unframed_noContentType() throws Exception {
HttpClient client = HttpClientFactory.DEFAULT
.newClient("none+" + server.httpUri("/"),
HttpClient.class);
AggregatedHttpMessage response = client.execute(
HttpHeaders.of(HttpMethod.POST,
UnitTestServiceGrpc.METHOD_STATIC_UNARY_CALL.getFullMethodName()),
REQUEST_MESSAGE.toByteArray()).aggregate().get();
assertThat(response.status()).isEqualTo(HttpStatus.UNSUPPORTED_MEDIA_TYPE);
}
@Test
public void unframed_grpcEncoding() throws Exception {
HttpClient client = HttpClientFactory.DEFAULT
.newClient("none+" + server.httpUri("/"),
HttpClient.class);
AggregatedHttpMessage response = client.execute(
HttpHeaders.of(HttpMethod.POST,
UnitTestServiceGrpc.METHOD_STATIC_UNARY_CALL.getFullMethodName())
.set(HttpHeaderNames.CONTENT_TYPE, "application/protobuf")
.set(GrpcHeaderNames.GRPC_ENCODING, "gzip"),
REQUEST_MESSAGE.toByteArray()).aggregate().get();
assertThat(response.status()).isEqualTo(HttpStatus.UNSUPPORTED_MEDIA_TYPE);
}
@Test
public void unframed_serviceError() throws Exception {
HttpClient client = HttpClientFactory.DEFAULT
.newClient("none+" + server.httpUri("/"),
HttpClient.class);
AggregatedHttpMessage response = client.execute(
HttpHeaders.of(HttpMethod.POST,
UnitTestServiceGrpc.METHOD_STATIC_UNARY_CALL.getFullMethodName())
.set(HttpHeaderNames.CONTENT_TYPE, "application/protobuf"),
SimpleRequest.newBuilder()
.setResponseStatus(
EchoStatus.newBuilder()
.setCode(Status.DEADLINE_EXCEEDED.getCode().value()))
.build().toByteArray()).aggregate().get();
assertThat(response.status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR);
}
@Test
public void grpcWeb() throws Exception {
HttpClient client = HttpClientFactory.DEFAULT
.newClient("none+" + server.httpUri("/"),
HttpClient.class);
AggregatedHttpMessage response = client.execute(
HttpHeaders.of(HttpMethod.POST,
UnitTestServiceGrpc.METHOD_STATIC_UNARY_CALL.getFullMethodName())
.set(HttpHeaderNames.CONTENT_TYPE, "application/grpc-web"),
GrpcTestUtil.uncompressedFrame(GrpcTestUtil.requestByteBuf())).aggregate().get();
byte[] serializedStatusHeader = "grpc-status: 0\r\n".getBytes(StandardCharsets.US_ASCII);
byte[] serializedTrailers = Bytes.concat(
new byte[] { ArmeriaServerCall.TRAILERS_FRAME_HEADER },
Ints.toByteArray(serializedStatusHeader.length),
serializedStatusHeader);
assertThat(response.content().array()).containsExactly(
Bytes.concat(
GrpcTestUtil.uncompressedFrame(
GrpcTestUtil.protoByteBuf(GrpcTestUtil.RESPONSE_MESSAGE)),
serializedTrailers));
}
@Test
public void json() throws Exception {
HttpClient client = HttpClientFactory.DEFAULT
.newClient("none+" + server.httpUri("/"),
HttpClient.class);
ByteBuf request = Unpooled.wrappedBuffer(
JsonFormat.printer().print(GrpcTestUtil.REQUEST_MESSAGE).getBytes(StandardCharsets.UTF_8));
AggregatedHttpMessage response = client.execute(
HttpHeaders.of(HttpMethod.POST,
UnitTestServiceGrpc.METHOD_STATIC_UNARY_CALL.getFullMethodName())
.set(HttpHeaderNames.CONTENT_TYPE, "application/grpc+json"),
GrpcTestUtil.uncompressedFrame(request)).aggregate().get();
ByteBuf responseMessage =
Unpooled.wrappedBuffer(JsonFormat.printer()
.print(GrpcTestUtil.RESPONSE_MESSAGE)
.getBytes(StandardCharsets.UTF_8));
assertThat(response.content().array()).containsExactly(GrpcTestUtil.uncompressedFrame(responseMessage));
}
}