/* Copyright (c) 2015 LinkedIn Corp. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package com.linkedin.r2.transport.http.client; import com.linkedin.common.util.None; import com.linkedin.data.ByteString; import com.linkedin.r2.RemoteInvocationException; import com.linkedin.r2.filter.R2Constants; import com.linkedin.r2.message.stream.StreamResponseBuilder; import com.linkedin.r2.message.stream.entitystream.EntityStream; import com.linkedin.r2.message.stream.entitystream.EntityStreams; import com.linkedin.r2.message.stream.entitystream.WriteHandle; import com.linkedin.r2.message.stream.entitystream.Writer; import com.linkedin.r2.transport.http.common.HttpConstants; import com.linkedin.r2.util.Timeout; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.TooLongFrameException; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.LastHttpContent; import io.netty.util.AttributeKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.InputStream; import java.nio.channels.ClosedChannelException; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.TimeoutException; /** * This Decoder decodes chunked Netty responses into StreamResponse. * * @author Zhenkai Zhu */ /* package private */ class RAPResponseDecoder extends SimpleChannelInboundHandler<HttpObject> { private static final Logger LOG = LoggerFactory.getLogger(RAPResponseDecoder.class); public static final AttributeKey<Timeout<None>> TIMEOUT_ATTR_KEY = AttributeKey.valueOf("TimeoutExecutor"); private static final FullHttpResponse CONTINUE = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER); private static final int BUFFER_HIGH_WATER_MARK = 3 * R2Constants.DEFAULT_DATA_CHUNK_SIZE; private static final int BUFFER_LOW_WATER_MARK = R2Constants.DEFAULT_DATA_CHUNK_SIZE; private final long _maxContentLength; private TimeoutBufferedWriter _chunkedMessageWriter; boolean _shouldCloseConnection; RAPResponseDecoder(long maxContentLength) { _maxContentLength = maxContentLength; } @Override protected void channelRead0(final ChannelHandlerContext ctx, HttpObject msg) throws Exception { if (msg instanceof HttpResponse) { HttpResponse m = (HttpResponse) msg; _shouldCloseConnection = !HttpUtil.isKeepAlive(m); if (HttpUtil.is100ContinueExpected(m)) { ctx.writeAndFlush(CONTINUE).addListener(new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (!future.isSuccess()) { ctx.fireExceptionCaught(future.cause()); } } }); } if (!m.decoderResult().isSuccess()) { ctx.fireExceptionCaught(m.decoderResult().cause()); return; } // remove chunked encoding. if (HttpUtil.isTransferEncodingChunked(m)) { HttpUtil.setTransferEncodingChunked(m, false); } Timeout<None> timeout = ctx.channel().attr(TIMEOUT_ATTR_KEY).getAndRemove(); if (timeout == null) { LOG.debug("dropped a response after channel inactive or exception had happened."); return; } final TimeoutBufferedWriter writer = new TimeoutBufferedWriter(ctx, _maxContentLength, BUFFER_HIGH_WATER_MARK, BUFFER_LOW_WATER_MARK, timeout); EntityStream entityStream = EntityStreams.newEntityStream(writer); _chunkedMessageWriter = writer; StreamResponseBuilder builder = new StreamResponseBuilder(); builder.setStatus(m.status().code()); for (Map.Entry<String, String> e : m.headers()) { String key = e.getKey(); String value = e.getValue(); if (key.equalsIgnoreCase(HttpConstants.RESPONSE_COOKIE_HEADER_NAME)) { builder.addCookie(value); } else { builder.unsafeAddHeaderValue(key, value); } } ctx.fireChannelRead(builder.build(entityStream)); } else if (msg instanceof HttpContent) { HttpContent chunk = (HttpContent) msg; TimeoutBufferedWriter currentWriter = _chunkedMessageWriter; // Sanity check if (currentWriter == null) { throw new IllegalStateException( "received " + HttpContent.class.getSimpleName() + " without " + HttpResponse.class.getSimpleName()); } if (!chunk.decoderResult().isSuccess()) { this.exceptionCaught(ctx, chunk.decoderResult().cause()); } currentWriter.processHttpChunk(chunk); if (chunk instanceof LastHttpContent) { _chunkedMessageWriter = null; } } else { // something must be wrong, but let's proceed so that // handler after us has a chance to process it. ctx.fireChannelRead(msg); } } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { Timeout<None> timeout = ctx.channel().attr(TIMEOUT_ATTR_KEY).getAndRemove(); if (timeout != null) { timeout.getItem(); } if (_chunkedMessageWriter != null) { _chunkedMessageWriter.fail(new ClosedChannelException()); _chunkedMessageWriter = null; } ctx.fireChannelInactive(); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { Timeout<None> timeout = ctx.channel().attr(TIMEOUT_ATTR_KEY).getAndRemove(); if (timeout != null) { timeout.getItem(); } if (_chunkedMessageWriter != null) { _chunkedMessageWriter.fail(cause); _chunkedMessageWriter = null; } ctx.fireExceptionCaught(cause); } /** * A buffered writer that stops reading from socket if buffered bytes is larger than high water mark * and resumes reading from socket if buffered bytes is smaller than low water mark. */ private class TimeoutBufferedWriter implements Writer { private final ChannelHandlerContext _ctx; private final long _maxContentLength; private final int _highWaterMark; private final int _lowWaterMark; private WriteHandle _wh; private boolean _lastChunkReceived; private int _totalBytesWritten; private int _bufferedBytes; private final List<ByteString> _buffer; private final Timeout<None> _timeout; private volatile Throwable _failureBeforeInit; TimeoutBufferedWriter(final ChannelHandlerContext ctx, long maxContentLength, int highWaterMark, int lowWaterMark, Timeout<None> timeout) { _ctx = ctx; _maxContentLength = maxContentLength; _highWaterMark = highWaterMark; _lowWaterMark = lowWaterMark; _failureBeforeInit = null; _lastChunkReceived = false; _totalBytesWritten = 0; _bufferedBytes = 0; _buffer = new LinkedList<ByteString>(); // schedule a timeout to close the channel and inform use Runnable timeoutTask = new Runnable() { @Override public void run() { _ctx.executor().execute(new Runnable() { @Override public void run() { final Exception ex = new TimeoutException("Timeout while receiving the response entity."); fail(ex); ctx.fireExceptionCaught(ex); } }); } }; _timeout = timeout; _timeout.addTimeoutTask(timeoutTask); } @Override public void onInit(WriteHandle wh) { _wh = wh; } @Override public void onWritePossible() { if (_failureBeforeInit != null) { fail(_failureBeforeInit); return; } if (_ctx.executor().inEventLoop()) { doWrite(); } else { _ctx.executor().execute(new Runnable() { @Override public void run() { doWrite(); } }); } } @Override public void onAbort(Throwable ex) { _timeout.getItem(); _ctx.fireChannelRead(ChannelPoolStreamHandler.CHANNEL_DESTROY_SIGNAL); } public void processHttpChunk(HttpContent chunk) throws TooLongFrameException { if (chunk.content().readableBytes() + _totalBytesWritten > _maxContentLength) { TooLongFrameException ex = new TooLongFrameException("HTTP content length exceeded " + _maxContentLength + " bytes."); fail(ex); _chunkedMessageWriter = null; throw ex; } else { if (chunk.content().isReadable()) { ByteBuf rawData = chunk.content(); InputStream is = new ByteBufInputStream(rawData); final ByteString data; try { data = ByteString.read(is, rawData.readableBytes()); } catch (IOException ex) { fail(ex); return; } _buffer.add(data); _bufferedBytes += data.length(); if (_bufferedBytes > _highWaterMark && _ctx.channel().config().isAutoRead()) { // stop reading from socket because we buffered too much _ctx.channel().config().setAutoRead(false); } } if (chunk instanceof LastHttpContent) { _lastChunkReceived = true; } if (_wh != null) { doWrite(); } } } public void fail(Throwable ex) { _timeout.getItem(); if (_wh != null) { _wh.error(new RemoteInvocationException(ex)); } else { _failureBeforeInit = ex; } } private void doWrite() { while(_wh.remaining() > 0) { if (!_buffer.isEmpty()) { ByteString data = _buffer.remove(0); _wh.write(data); _bufferedBytes -= data.length(); _totalBytesWritten += data.length(); if (!_ctx.channel().config().isAutoRead() && _bufferedBytes < _lowWaterMark) { // resume reading from socket _ctx.channel().config().setAutoRead(true); } } else { if (_lastChunkReceived) { _wh.done(); _timeout.getItem(); if (_shouldCloseConnection) { _ctx.fireChannelRead(ChannelPoolStreamHandler.CHANNEL_DESTROY_SIGNAL); } else { _ctx.fireChannelRead(ChannelPoolStreamHandler.CHANNEL_RELEASE_SIGNAL); } } break; } } } } }