/* * Copyright 2017 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ /* * Copyright 2014, Google Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are * met: * * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above * copyright notice, this list of conditions and the following disclaimer * in the documentation and/or other materials provided with the * distribution. * * * Neither the name of Google Inc. nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ package com.linecorp.armeria.internal.grpc; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Objects; import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; import com.linecorp.armeria.common.http.HttpData; import io.grpc.Codec; import io.grpc.Codec.Identity; import io.grpc.Decompressor; import io.grpc.Status; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.CompositeByteBuf; /** * A deframer of messages transported in the GRPC wire format. See * <a href="http://www.grpc.io/docs/guides/wire.html">GRPC Wire Protocol</a> for more detail on the protocol. * * <p>The logic has been mostly copied from {@code io.grpc.internal.MessageDeframer}, while removing the buffer * abstraction in favor of using {@link ByteBuf} directly, and allowing the delivery of uncompressed frames as * a {@link ByteBuf} to optimize message parsing. */ public class ArmeriaMessageDeframer implements AutoCloseable { private static final String DEBUG_STRING = ArmeriaMessageDeframer.class.getName(); private static final int HEADER_LENGTH = 5; private static final int COMPRESSED_FLAG_MASK = 1; private static final int RESERVED_MASK = 0xFE; /** * A deframed message. For uncompressed messages, we have the entire buffer available and return it * as is in {@code buf} to optimize parsing. For compressed messages, we will parse incrementally * and thus return a {@link InputStream} in {@code stream}. */ public static class ByteBufOrStream { @Nullable private final ByteBuf buf; @Nullable private final InputStream stream; @VisibleForTesting public ByteBufOrStream(ByteBuf buf) { this(requireNonNull(buf, "buf"), null); } @VisibleForTesting public ByteBufOrStream(InputStream stream) { this(null, requireNonNull(stream, "stream")); } private ByteBufOrStream(@Nullable ByteBuf buf, @Nullable InputStream stream) { this.buf = buf; this.stream = stream; } @Nullable public ByteBuf buf() { return buf; } @Nullable public InputStream stream() { return stream; } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } ByteBufOrStream that = (ByteBufOrStream) o; return Objects.equals(buf, that.buf) && Objects.equals(stream, that.stream); } @Override public int hashCode() { return Objects.hash(buf, stream); } } /** * A listener of deframing events. */ public interface Listener { /** * Called to deliver the next complete message. Either {@code message.buf} or {@code message.stream} * will be non-null. {@code message.buf} must be released, or {@code message.stream} must be closed by * the callee. */ void messageRead(ByteBufOrStream message); /** * Called when the stream is complete and all messages have been successfully delivered. */ void endOfStream(); } private enum State { HEADER, BODY } private final Listener listener; private final int maxMessageSizeBytes; private final ByteBufAllocator alloc; private State state = State.HEADER; private int requiredLength = HEADER_LENGTH; private Decompressor decompressor = Identity.NONE; private boolean compressedFlag; private boolean endOfStream; private CompositeByteBuf nextFrame; private CompositeByteBuf unprocessed; private long pendingDeliveries; private boolean deliveryStalled = true; private boolean inDelivery; private boolean startedDeframing; public ArmeriaMessageDeframer(Listener listener, int maxMessageSizeBytes, ByteBufAllocator alloc) { this.listener = requireNonNull(listener, "listener"); this.maxMessageSizeBytes = maxMessageSizeBytes; this.alloc = requireNonNull(alloc, "alloc"); unprocessed = alloc.compositeBuffer(); } /** * Requests up to the given number of messages from the call to be delivered to * {@link Listener#messageRead(ByteBufOrStream)}. No additional messages will be delivered. * * <p>If {@link #close()} has been called, this method will have no effect. * * @param numMessages the requested number of messages to be delivered to the listener. */ public void request(int numMessages) { checkArgument(numMessages > 0, "numMessages must be > 0"); if (isClosed()) { return; } pendingDeliveries += numMessages; deliver(); } /** * Indicates whether delivery is currently stalled, pending receipt of more data. This means * that no additional data can be delivered to the application. */ public boolean isStalled() { return deliveryStalled; } /** * Adds the given data to this deframer and attempts delivery to the listener. * * @param data the raw data read from the remote endpoint. Must be non-null. * @param endOfStream if {@code true}, indicates that {@code data} is the end of the stream from * the remote endpoint. End of stream should not be used in the event of a transport * error, such as a stream reset. * @throws IllegalStateException if {@link #close()} has been called previously or if * this method has previously been called with {@code endOfStream=true}. */ public void deframe(HttpData data, boolean endOfStream) { requireNonNull(data, "data"); checkNotClosed(); checkState(!this.endOfStream, "Past end of stream"); startedDeframing = true; if (!data.isEmpty()) { ByteBuf buf = alloc.buffer(data.length()); buf.writeBytes(data.array(), data.offset(), data.length()); unprocessed.addComponent(true, buf); } // Indicate that all of the data for this stream has been received. this.endOfStream = endOfStream; deliver(); } /** * Closes this deframer and frees any resources. After this method is called, additional * calls will have no effect. */ @Override public void close() { try { if (unprocessed != null) { unprocessed.release(); } if (nextFrame != null) { nextFrame.release(); } } finally { unprocessed = null; nextFrame = null; } } /** * Indicates whether or not this deframer has been closed. */ public boolean isClosed() { return unprocessed == null; } public ArmeriaMessageDeframer decompressor(Decompressor decompressor) { checkState(!startedDeframing, "Deframing has already started, cannot change decompressor mid-stream."); this.decompressor = decompressor; return this; } /** * Throws if this deframer has already been closed. */ private void checkNotClosed() { checkState(!isClosed(), "MessageDeframer is already closed"); } /** * Reads and delivers as many messages to the listener as possible. */ private void deliver() { // We can have reentrancy here when using a direct executor, triggered by calls to // request more messages. This is safe as we simply loop until pendingDelivers = 0 if (inDelivery) { return; } inDelivery = true; try { // Process the uncompressed bytes. while (pendingDeliveries > 0 && readRequiredBytes()) { switch (state) { case HEADER: processHeader(); break; case BODY: // Read the body and deliver the message. processBody(); // Since we've delivered a message, decrement the number of pending // deliveries remaining. pendingDeliveries--; break; default: throw new IllegalStateException("Invalid state: " + state); } } /* * We are stalled when there are no more bytes to process. This allows delivering errors as * soon as the buffered input has been consumed, independent of whether the application * has requested another message. At this point in the function, either all frames have been * delivered, or unprocessed is empty. If there is a partial message, it will be inside next * frame and not in unprocessed. If there is extra data but no pending deliveries, it will * be in unprocessed. */ boolean stalled = !unprocessed.isReadable(); if (endOfStream && stalled) { boolean havePartialMessage = nextFrame != null && nextFrame.isReadable(); if (!havePartialMessage) { listener.endOfStream(); deliveryStalled = false; return; } else { // We've received the entire stream and have data available but we don't have // enough to read the next frame ... this is bad. throw Status.INTERNAL.withDescription( DEBUG_STRING + ": Encountered end-of-stream mid-frame").asRuntimeException(); } } deliveryStalled = stalled; } finally { inDelivery = false; } } /** * Attempts to read the required bytes into nextFrame. * * @return {@code true} if all of the required bytes have been read. */ private boolean readRequiredBytes() { if (nextFrame == null) { nextFrame = alloc.compositeBuffer(); } // Read until the buffer contains all the required bytes. int missingBytes; while ((missingBytes = requiredLength - nextFrame.readableBytes()) > 0) { int numUnprocessedBytes = unprocessed.readableBytes(); if (numUnprocessedBytes == 0) { // No more data is available. return false; } int toRead = Math.min(missingBytes, numUnprocessedBytes); if (toRead > 0) { nextFrame.addComponent(true, unprocessed.readBytes(toRead)); unprocessed.discardReadComponents(); } } return true; } /** * Processes the GRPC compression header which is composed of the compression flag and the outer * frame length. */ private void processHeader() { int type = nextFrame.readUnsignedByte(); if ((type & RESERVED_MASK) != 0) { throw Status.INTERNAL.withDescription( DEBUG_STRING + ": Frame header malformed: reserved bits not zero") .asRuntimeException(); } compressedFlag = (type & COMPRESSED_FLAG_MASK) != 0; // Update the required length to include the length of the frame. requiredLength = nextFrame.readInt(); if (requiredLength < 0 || requiredLength > maxMessageSizeBytes) { throw Status.RESOURCE_EXHAUSTED.withDescription( String.format("%s: Frame size %d exceeds maximum: %d. ", DEBUG_STRING, requiredLength, maxMessageSizeBytes)).asRuntimeException(); } // Continue reading the frame body. state = State.BODY; } /** * Processes the body of the GRPC compression frame. A single compression frame may contain * several GRPC messages within it. */ private void processBody() { ByteBufOrStream msg = compressedFlag ? getCompressedBody() : getUncompressedBody(); nextFrame = null; listener.messageRead(msg); // Done with this frame, begin processing the next header. state = State.HEADER; requiredLength = HEADER_LENGTH; } private ByteBufOrStream getUncompressedBody() { return new ByteBufOrStream(nextFrame.consolidate()); } private ByteBufOrStream getCompressedBody() { if (decompressor == Codec.Identity.NONE) { throw Status.INTERNAL.withDescription( DEBUG_STRING + ": Can't decode compressed frame as compression not configured.") .asRuntimeException(); } try { // Enforce the maxMessageSizeBytes limit on the returned stream. InputStream unlimitedStream = decompressor.decompress(new ByteBufInputStream(nextFrame, true)); return new ByteBufOrStream( new SizeEnforcingInputStream(unlimitedStream, maxMessageSizeBytes, DEBUG_STRING)); } catch (IOException e) { throw new RuntimeException(e); } } /** * An {@link InputStream} that enforces the {@link #maxMessageSize} limit for compressed frames. */ @VisibleForTesting static final class SizeEnforcingInputStream extends FilterInputStream { private final int maxMessageSize; private final String debugString; private long maxCount; private long count; private long mark = -1; SizeEnforcingInputStream(InputStream in, int maxMessageSize, String debugString) { super(in); this.maxMessageSize = maxMessageSize; this.debugString = debugString; } @Override public int read() throws IOException { int result = in.read(); if (result != -1) { count++; } verifySize(); reportCount(); return result; } @Override public int read(byte[] b, int off, int len) throws IOException { int result = in.read(b, off, len); if (result != -1) { count += result; } verifySize(); reportCount(); return result; } @Override public long skip(long n) throws IOException { long result = in.skip(n); count += result; verifySize(); reportCount(); return result; } @Override public synchronized void mark(int readlimit) { in.mark(readlimit); mark = count; // it's okay to mark even if mark isn't supported, as reset won't work } @Override public synchronized void reset() throws IOException { if (!in.markSupported()) { throw new IOException("Mark not supported"); } if (mark == -1) { throw new IOException("Mark not set"); } in.reset(); count = mark; } private void reportCount() { if (count > maxCount) { maxCount = count; } } private void verifySize() { if (count > maxMessageSize) { throw Status.INTERNAL.withDescription(String.format( "%s: Compressed frame exceeds maximum frame size: %d. Bytes read: %d. ", debugString, maxMessageSize, count)).asRuntimeException(); } } } }