/* * Copyright 2017 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.linecorp.armeria.internal.grpc; import static java.util.Objects.requireNonNull; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import com.google.common.io.ByteStreams; import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import com.google.protobuf.util.JsonFormat; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.grpc.GrpcSerializationFormats; import com.linecorp.armeria.internal.grpc.ArmeriaMessageDeframer.ByteBufOrStream; import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.PrototypeMarshaller; import io.grpc.Status; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.ByteBufOutputStream; import io.netty.buffer.CompositeByteBuf; /** * Marshaller for GRPC method request or response messages to and from {@link ByteBuf}. Will attempt to use * optimized code paths for known message types, and otherwise delegates to the GRPC stub. */ public class GrpcMessageMarshaller<I, O> { private enum MessageType { UNKNOWN, PROTOBUF } private final ByteBufAllocator alloc; private final SerializationFormat serializationFormat; private final MethodDescriptor<I, O> method; private final MessageType requestType; private final MessageType responseType; public GrpcMessageMarshaller(ByteBufAllocator alloc, SerializationFormat serializationFormat, MethodDescriptor<I, O> method) { this.alloc = requireNonNull(alloc, "alloc"); this.serializationFormat = requireNonNull(serializationFormat, "serializationFormat"); this.method = requireNonNull(method, "method"); requestType = marshallerType(method.getRequestMarshaller()); responseType = marshallerType(method.getResponseMarshaller()); } public ByteBuf serializeRequest(I message) throws IOException { switch (requestType) { case PROTOBUF: return serializeProto((Message) message); default: CompositeByteBuf out = alloc.compositeBuffer(); try (ByteBufOutputStream os = new ByteBufOutputStream(out)) { ByteStreams.copy(method.streamRequest(message), os); } return out; } } public I deserializeRequest(ByteBufOrStream message) throws IOException { InputStream messageStream = message.stream(); if (message.buf() != null) { try { switch (requestType) { case PROTOBUF: PrototypeMarshaller<I> marshaller = (PrototypeMarshaller<I>) method.getRequestMarshaller(); // PrototypeMarshaller<I>.getMessagePrototype will always parse to I @SuppressWarnings("unchecked") I msg = (I) deserializeProto(message.buf(), (Message) marshaller.getMessagePrototype()); return msg; default: // Fallback to using the method's stream marshaller. messageStream = new ByteBufInputStream(message.buf().retain(), true); break; } } finally { message.buf().release(); } } try (InputStream msg = messageStream) { return method.parseRequest(msg); } } public ByteBuf serializeResponse(O message) throws IOException { switch (responseType) { case PROTOBUF: return serializeProto((Message) message); default: CompositeByteBuf out = alloc.compositeBuffer(); try (ByteBufOutputStream os = new ByteBufOutputStream(out)) { ByteStreams.copy(method.streamResponse(message), os); } return out; } } public O deserializeResponse(ByteBufOrStream message) throws IOException { InputStream messageStream = message.stream(); if (message.buf() != null) { try { switch (responseType) { case PROTOBUF: PrototypeMarshaller<O> marshaller = (PrototypeMarshaller<O>) method.getResponseMarshaller(); // PrototypeMarshaller<I>.getMessagePrototype will always parse to I @SuppressWarnings("unchecked") O msg = (O) deserializeProto(message.buf(), (Message) marshaller.getMessagePrototype()); return msg; default: // Fallback to using the method's stream marshaller. messageStream = new ByteBufInputStream(message.buf().retain(), true); break; } } finally { message.buf().release(); } } try (InputStream msg = messageStream) { return method.parseResponse(msg); } } private ByteBuf serializeProto(Message message) throws IOException { if (GrpcSerializationFormats.isProto(serializationFormat)) { ByteBuf buf = alloc.buffer(message.getSerializedSize()); boolean success = false; try { message.writeTo(CodedOutputStream.newInstance(buf.nioBuffer(0, buf.writableBytes()))); buf.writerIndex(buf.capacity()); success = true; } finally { if (!success) { buf.release(); } } return buf; } else if (GrpcSerializationFormats.isJson(serializationFormat)) { ByteBuf buf = alloc.buffer(); boolean success = false; try { buf.writeCharSequence(JsonFormat.printer().print(message), StandardCharsets.UTF_8); success = true; } finally { if (!success) { buf.release(); } } return buf; } throw new IllegalStateException("Unknown serialization format: " + serializationFormat); } private Message deserializeProto(ByteBuf buf, Message prototype) throws IOException { if (GrpcSerializationFormats.isProto(serializationFormat)) { CodedInputStream stream = CodedInputStream.newInstance(buf.nioBuffer()); try { Message msg = prototype.getParserForType().parseFrom(stream); try { stream.checkLastTagWas(0); } catch (InvalidProtocolBufferException e) { e.setUnfinishedMessage(msg); throw e; } return msg; } catch (InvalidProtocolBufferException e) { throw Status.INTERNAL.withDescription("Invalid protobuf byte sequence") .withCause(e).asRuntimeException(); } } else if (GrpcSerializationFormats.isJson(serializationFormat)) { Message.Builder builder = prototype.newBuilderForType(); JsonFormat.parser().merge(buf.toString(StandardCharsets.UTF_8), builder); return builder.build(); } throw new IllegalStateException("Unknown serialization format: " + serializationFormat); } private static MessageType marshallerType(Marshaller<?> marshaller) { return marshaller instanceof PrototypeMarshaller ? MessageType.PROTOBUF : MessageType.UNKNOWN; } }