/* * Copyright 2017 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.linecorp.armeria.internal.grpc; import java.io.IOException; import java.io.InputStream; import java.util.LinkedList; import java.util.Queue; import java.util.Random; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import javax.annotation.concurrent.GuardedBy; import com.google.common.base.Preconditions; import com.google.common.collect.Queues; import com.google.protobuf.ByteString; import com.linecorp.armeria.common.http.FilteredHttpResponse; import com.linecorp.armeria.common.http.HttpHeaderNames; import com.linecorp.armeria.common.http.HttpHeaders; import com.linecorp.armeria.common.http.HttpObject; import com.linecorp.armeria.common.http.HttpRequest; import com.linecorp.armeria.common.http.HttpResponse; import com.linecorp.armeria.grpc.testing.Messages; import com.linecorp.armeria.grpc.testing.Messages.PayloadType; import com.linecorp.armeria.grpc.testing.Messages.ResponseParameters; import com.linecorp.armeria.grpc.testing.Messages.SimpleRequest; import com.linecorp.armeria.grpc.testing.Messages.SimpleResponse; import com.linecorp.armeria.grpc.testing.Messages.StreamingInputCallRequest; import com.linecorp.armeria.grpc.testing.Messages.StreamingInputCallResponse; import com.linecorp.armeria.grpc.testing.Messages.StreamingOutputCallRequest; import com.linecorp.armeria.grpc.testing.Messages.StreamingOutputCallResponse; import com.linecorp.armeria.grpc.testing.TestServiceGrpc; import com.linecorp.armeria.protobuf.EmptyProtos; import com.linecorp.armeria.protobuf.EmptyProtos.Empty; import com.linecorp.armeria.server.Service; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.server.SimpleDecoratingService; import io.grpc.Status; import io.grpc.internal.LogExceptionRunnable; import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.netty.util.AsciiString; public class TestServiceImpl extends TestServiceGrpc.TestServiceImplBase { public static final AsciiString EXTRA_HEADER_NAME = HttpHeaderNames.of("extra-header"); private static final String UNCOMPRESSABLE_FILE = "/io/grpc/testing/integration/testdata/uncompressable.bin"; private final Random random = new Random(); private final ScheduledExecutorService executor; private final ByteString uncompressableBuffer; private final ByteString compressableBuffer; /** * Constructs a controller using the given executor for scheduling response stream chunks. */ public TestServiceImpl(ScheduledExecutorService executor) { this.executor = executor; this.compressableBuffer = ByteString.copyFrom(new byte[1024]); this.uncompressableBuffer = createBufferFromFile(UNCOMPRESSABLE_FILE); } @Override public void emptyCall(EmptyProtos.Empty empty, StreamObserver<Empty> responseObserver) { responseObserver.onNext(Empty.getDefaultInstance()); responseObserver.onCompleted(); } /** * Immediately responds with a payload of the type and size specified in the request. */ @Override public void unaryCall(SimpleRequest req, StreamObserver<SimpleResponse> responseObserver) { ServerCallStreamObserver<SimpleResponse> obs = (ServerCallStreamObserver<SimpleResponse>) responseObserver; SimpleResponse.Builder responseBuilder = SimpleResponse.newBuilder(); try { switch (req.getResponseCompression()) { case DEFLATE: // fallthrough, just use gzip case GZIP: obs.setCompression("gzip"); break; case NONE: obs.setCompression("identity"); break; case UNRECOGNIZED: // fallthrough default: obs.onError(Status.INVALID_ARGUMENT .withDescription("Unknown: " + req.getResponseCompression()) .asRuntimeException()); return; } } catch (IllegalArgumentException e) { obs.onError(Status.UNIMPLEMENTED .withDescription("compression not supported.") .withCause(e) .asRuntimeException()); return; } if (req.getResponseSize() != 0) { boolean compressable = compressableResponse(req.getResponseType()); ByteString dataBuffer = compressable ? compressableBuffer : uncompressableBuffer; // For consistency with the c++ TestServiceImpl, use a random offset for unary calls. // TODO(wonderfly): whether or not this is a good approach needs further discussion. int offset = random.nextInt( compressable ? compressableBuffer.size() : uncompressableBuffer.size()); ByteString payload = generatePayload(dataBuffer, offset, req.getResponseSize()); responseBuilder.getPayloadBuilder() .setType(compressable ? PayloadType.COMPRESSABLE : PayloadType.UNCOMPRESSABLE) .setBody(payload); } if (req.hasResponseStatus()) { obs.onError(Status.fromCodeValue(req.getResponseStatus().getCode()) .withDescription(req.getResponseStatus().getMessage()) .asRuntimeException()); return; } responseObserver.onNext(responseBuilder.build()); responseObserver.onCompleted(); } /** * Given a request that specifies chunk size and interval between responses, creates and schedules * the response stream. */ @Override public void streamingOutputCall(StreamingOutputCallRequest request, StreamObserver<StreamingOutputCallResponse> responseObserver) { // Create and start the response dispatcher. new ResponseDispatcher(responseObserver).enqueue(toChunkQueue(request)).completeInput(); } /** * Waits until we have received all of the request messages and then returns the aggregate payload * size for all of the received requests. */ @Override public StreamObserver<Messages.StreamingInputCallRequest> streamingInputCall( final StreamObserver<Messages.StreamingInputCallResponse> responseObserver) { return new StreamObserver<StreamingInputCallRequest>() { private int totalPayloadSize; @Override public void onNext(StreamingInputCallRequest message) { totalPayloadSize += message.getPayload().getBody().size(); } @Override public void onCompleted() { responseObserver.onNext(StreamingInputCallResponse.newBuilder() .setAggregatedPayloadSize( totalPayloadSize).build()); responseObserver.onCompleted(); } @Override public void onError(Throwable cause) { responseObserver.onError(cause); } }; } /** * True bi-directional streaming. Processes requests as they come in. Begins streaming results * immediately. */ @Override public StreamObserver<StreamingOutputCallRequest> fullDuplexCall( final StreamObserver<StreamingOutputCallResponse> responseObserver) { final ResponseDispatcher dispatcher = new ResponseDispatcher(responseObserver); return new StreamObserver<StreamingOutputCallRequest>() { @Override public void onNext(StreamingOutputCallRequest request) { if (request.hasResponseStatus()) { dispatcher.cancel(); responseObserver.onError(Status.fromCodeValue(request.getResponseStatus().getCode()) .withDescription( request.getResponseStatus().getMessage()) .asRuntimeException()); return; } dispatcher.enqueue(toChunkQueue(request)); } @Override public void onCompleted() { if (!dispatcher.isCancelled()) { // Tell the dispatcher that all input has been received. dispatcher.completeInput(); } } @Override public void onError(Throwable cause) { responseObserver.onError(cause); } }; } /** * Similar to {@link #fullDuplexCall}, except that it waits for all streaming requests to be * received before starting the streaming responses. */ @Override public StreamObserver<StreamingOutputCallRequest> halfDuplexCall( final StreamObserver<StreamingOutputCallResponse> responseObserver) { final Queue<Chunk> chunks = new LinkedList<Chunk>(); return new StreamObserver<StreamingOutputCallRequest>() { @Override public void onNext(StreamingOutputCallRequest request) { chunks.addAll(toChunkQueue(request)); } @Override public void onCompleted() { // Dispatch all of the chunks in one shot. new ResponseDispatcher(responseObserver).enqueue(chunks).completeInput(); } @Override public void onError(Throwable cause) { responseObserver.onError(cause); } }; } /** * Schedules the dispatch of a queue of chunks. Whenever chunks are added or input is completed, * the next response chunk is scheduled for delivery to the client. When no more chunks are * available, the stream is half-closed. */ private class ResponseDispatcher { private final Chunk completionChunk = new Chunk(0, 0, 0, false); private final Queue<Chunk> chunks; private final StreamObserver<StreamingOutputCallResponse> responseStream; private boolean scheduled; @GuardedBy("this") private boolean cancelled; private Throwable failure; private Runnable dispatchTask = new Runnable() { @Override public void run() { try { // Dispatch the current chunk to the client. try { dispatchChunk(); } catch (RuntimeException e) { // Indicate that nothing is scheduled and re-throw. synchronized (ResponseDispatcher.this) { scheduled = false; } throw e; } // Schedule the next chunk if there is one. synchronized (ResponseDispatcher.this) { // Indicate that nothing is scheduled. scheduled = false; scheduleNextChunk(); } } catch (Throwable t) { t.printStackTrace(); } } }; ResponseDispatcher(StreamObserver<StreamingOutputCallResponse> responseStream) { this.chunks = Queues.newLinkedBlockingQueue(); this.responseStream = responseStream; } /** * Adds the given chunks to the response stream and schedules the next chunk to be delivered if * needed. */ synchronized ResponseDispatcher enqueue(Queue<Chunk> moreChunks) { assertNotFailed(); chunks.addAll(moreChunks); scheduleNextChunk(); return this; } /** * Indicates that the input is completed and the currently enqueued response chunks are all that * remain to be scheduled for dispatch to the client. */ ResponseDispatcher completeInput() { assertNotFailed(); chunks.add(completionChunk); scheduleNextChunk(); return this; } /** * Allows the service to cancel the remaining responses. */ synchronized void cancel() { Preconditions.checkState(!cancelled, "Dispatcher already cancelled"); chunks.clear(); cancelled = true; } synchronized boolean isCancelled() { return cancelled; } /** * Dispatches the current response chunk to the client. This is only called by the executor. At * any time, a given dispatch task should only be registered with the executor once. */ private synchronized void dispatchChunk() { if (cancelled) { return; } try { // Pop off the next chunk and send it to the client. Chunk chunk = chunks.remove(); if (chunk == completionChunk) { responseStream.onCompleted(); } else { responseStream.onNext(chunk.toResponse()); } } catch (Throwable e) { failure = e; if (Status.fromThrowable(e).getCode() == Status.CANCELLED.getCode()) { // Stream was cancelled by client, responseStream.onError() might be called already or // will be called soon by inbounding StreamObserver. chunks.clear(); } else { responseStream.onError(e); } } } /** * Schedules the next response chunk to be dispatched. If all input has been received and there * are no more chunks in the queue, the stream is closed. */ private void scheduleNextChunk() { synchronized (this) { if (scheduled) { // Dispatch task is already scheduled. return; } // Schedule the next response chunk if there is one. Chunk nextChunk = chunks.peek(); if (nextChunk != null) { scheduled = true; // TODO(ejona): cancel future if RPC is cancelled Future<?> unused = executor.schedule(new LogExceptionRunnable(dispatchTask), nextChunk.delayMicroseconds, TimeUnit.MICROSECONDS); return; } } } private void assertNotFailed() { if (failure != null) { throw new IllegalStateException("Stream already failed", failure); } } } /** * Breaks down the request and creates a queue of response chunks for the given request. */ Queue<Chunk> toChunkQueue(StreamingOutputCallRequest request) { Queue<Chunk> chunkQueue = new LinkedList<Chunk>(); int offset = 0; boolean compressable = compressableResponse(request.getResponseType()); for (ResponseParameters params : request.getResponseParametersList()) { chunkQueue.add(new Chunk(params.getIntervalUs(), offset, params.getSize(), compressable)); // Increment the offset past this chunk. // Both buffers need to be circular. offset = (offset + params.getSize()) % ( compressable ? compressableBuffer.size() : uncompressableBuffer.size()); } return chunkQueue; } /** * A single chunk of a response stream. Contains delivery information for the dispatcher and can * be converted to a streaming response proto. A chunk just references it's payload in the * {@link #uncompressableBuffer} array. The payload isn't actually created until {@link * #toResponse()} is called. */ private final class Chunk { private final int delayMicroseconds; private final int offset; private final int length; private final boolean compressable; private Chunk(int delayMicroseconds, int offset, int length, boolean compressable) { this.delayMicroseconds = delayMicroseconds; this.offset = offset; this.length = length; this.compressable = compressable; } /** * Convert this chunk into a streaming response proto. */ private StreamingOutputCallResponse toResponse() { StreamingOutputCallResponse.Builder responseBuilder = StreamingOutputCallResponse.newBuilder(); ByteString dataBuffer = compressable ? compressableBuffer : uncompressableBuffer; ByteString payload = generatePayload(dataBuffer, offset, length); responseBuilder.getPayloadBuilder() .setType(compressable ? PayloadType.COMPRESSABLE : PayloadType.UNCOMPRESSABLE) .setBody(payload); return responseBuilder.build(); } } /** * Creates a buffer with data read from a file. */ @SuppressWarnings("Finally") // Not concerned about suppression; expected to be exceedingly rare private ByteString createBufferFromFile(String fileClassPath) { ByteString buffer = ByteString.EMPTY; InputStream inputStream = getClass().getResourceAsStream(fileClassPath); if (inputStream == null) { throw new IllegalArgumentException("Unable to locate file on classpath: " + fileClassPath); } try { buffer = ByteString.readFrom(inputStream); } catch (IOException e) { throw new RuntimeException(e); } finally { try { inputStream.close(); } catch (IOException ignorable) { // ignore } } return buffer; } /** * Indicates whether or not the response for this type should be compressable. If {@code RANDOM}, * picks a random boolean. */ private boolean compressableResponse(PayloadType responseType) { switch (responseType) { case COMPRESSABLE: return true; case RANDOM: return random.nextBoolean(); case UNCOMPRESSABLE: default: return false; } } /** * Generates a payload of desired type and size. Reads compressableBuffer or * uncompressableBuffer as a circular buffer. */ private ByteString generatePayload(ByteString dataBuffer, int offset, int size) { ByteString payload = ByteString.EMPTY; // This offset would never pass the array boundary. int begin = offset; int end = 0; int bytesLeft = size; while (bytesLeft > 0) { end = Math.min(begin + bytesLeft, dataBuffer.size()); // ByteString.substring returns the substring from begin, inclusive, to end, exclusive. payload = payload.concat(dataBuffer.substring(begin, end)); bytesLeft -= (end - begin); begin = end % dataBuffer.size(); } return payload; } public static class EchoRequestHeadersInTrailers extends SimpleDecoratingService<HttpRequest, HttpResponse> { /** * Creates a new instance that decorates the specified {@link Service}. */ public EchoRequestHeadersInTrailers(Service<HttpRequest, HttpResponse> delegate) { super(delegate); } @Override public HttpResponse serve(ServiceRequestContext ctx, HttpRequest req) throws Exception { HttpResponse res = delegate().serve(ctx, req); return new FilteredHttpResponse(res) { private boolean headersReceived; @Override protected HttpObject filter(HttpObject obj) { if (obj instanceof HttpHeaders) { if (!headersReceived) { headersReceived = true; } else { HttpHeaders trailers = (HttpHeaders) obj; String extraHeader = req.headers().get(EXTRA_HEADER_NAME); if (extraHeader != null) { trailers.set(EXTRA_HEADER_NAME, extraHeader); } } } return obj; } }; } } }