/*
* The MIT License
*
* Copyright 2014 tim.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
package com.mastfrog.acteur.server;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.DefaultByteBufHolder;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaders;
import static io.netty.handler.codec.http.HttpHeaders.is100ContinueExpected;
import static io.netty.handler.codec.http.HttpHeaders.removeTransferEncodingChunked;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.CharsetUtil;
import java.util.List;
/**
 * Aggregates an {@link HttpMessage} and the {@link HttpContent} chunks that follow it into a
 * single {@link FullHttpRequest} or {@link FullHttpResponse}.
 * <p>
 * This is temporarily a copy of Netty's {@code HttpObjectAggregator}. The main difference is
 * that the {@code 100 Continue} status line is written as a raw {@link ByteBuf} rather than as a
 * {@link DefaultFullHttpResponse}. The reason is that the {@code HttpResponseEncoder} is getting
 * removed from the pipeline too early, apparently due to a race condition that still needs to be
 * sorted out.
 * <p>
 * This copy also passes already-aggregated {@link FullHttpMessage}s through unchanged. It
 * discards content chunks that arrive while no message is being aggregated.
 * <p>
 * Each instance keeps per-message aggregation state ({@code currentMessage}), so it belongs to
 * a single pipeline.
 */
final class HttpObjectAggregator extends MessageToMessageDecoder<HttpObject> {

    /** Default value of {@link #getMaxCumulationBufferComponents()}. */
    public static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024;

    /**
     * Pre-encoded {@code 100 Continue} status line, shared by all instances.
     * <p>
     * It is wrapped as unreleasable so that a downstream release can never free the shared
     * instance. Each write sends its own {@link ByteBuf#duplicate() duplicate}, so reader
     * indices are independent per write.
     */
    private static final ByteBuf CONTINUE_LINE = Unpooled.unreleasableBuffer(
            Unpooled.copiedBuffer("HTTP/1.1 100 CONTINUE\r\n\r\n", CharsetUtil.US_ASCII));

    private final int maxContentLength;

    /** The message currently being aggregated, or {@code null} between messages. */
    private AggregatedFullHttpMessage currentMessage;

    /**
     * Set when the current message exceeds {@link #maxContentLength}. Its remaining chunks are
     * then discarded until the next {@link HttpMessage} arrives.
     */
    private boolean tooLongFrameFound;

    private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS;

    /** Set by {@link #handlerAdded}; once non-null, configuration can no longer be changed. */
    private ChannelHandlerContext ctx;

    /**
     * Creates a new instance.
     *
     * @param maxContentLength
     *        the maximum length of the aggregated content.
     *        If the length of the aggregated content exceeds this value,
     *        a {@link TooLongFrameException} will be raised.
     * @throws IllegalArgumentException if {@code maxContentLength} is not positive
     */
    public HttpObjectAggregator(int maxContentLength) {
        if (maxContentLength <= 0) {
            throw new IllegalArgumentException(
                    "maxContentLength must be a positive integer: " +
                    maxContentLength);
        }
        this.maxContentLength = maxContentLength;
    }

    /**
     * Returns the maximum number of components in the cumulation buffer.
     * <p>
     * If the number of components exceeds this value, they are consolidated into a single
     * component, which involves memory copies. The default is
     * {@link #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS}.
     *
     * @return the configured component limit
     */
    public final int getMaxCumulationBufferComponents() {
        return maxCumulationBufferComponents;
    }

    /**
     * Sets the maximum number of components in the cumulation buffer.
     * <p>
     * If the number of components exceeds this value, they are consolidated into a single
     * component, which involves memory copies. The default is
     * {@link #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS} and the minimum allowed value is
     * {@code 2}.
     *
     * @param maxCumulationBufferComponents the new component limit, at least {@code 2}
     * @throws IllegalArgumentException if the value is less than {@code 2}
     * @throws IllegalStateException if this handler has already been added to a pipeline
     */
    public final void setMaxCumulationBufferComponents(int maxCumulationBufferComponents) {
        if (maxCumulationBufferComponents < 2) {
            throw new IllegalArgumentException(
                    "maxCumulationBufferComponents: " + maxCumulationBufferComponents +
                    " (expected: >= 2)");
        }
        if (ctx == null) {
            this.maxCumulationBufferComponents = maxCumulationBufferComponents;
        } else {
            throw new IllegalStateException(
                    "decoder properties cannot be changed once the decoder is added to a pipeline.");
        }
    }

    /**
     * Routes each inbound {@link HttpObject} to the appropriate handler.
     * <p>
     * An {@link HttpMessage} starts a new aggregate, or is emitted immediately when it cannot be
     * aggregated. An {@link HttpContent} is appended to the aggregate in progress. The completed
     * message is added to {@code out} when a {@link LastHttpContent}, or a chunk that failed to
     * decode, arrives.
     *
     * @throws TooLongFrameException if the aggregated content would exceed the configured maximum
     * @throws IllegalStateException if {@code msg} is neither an {@link HttpMessage} nor an
     *         {@link HttpContent}
     */
    @Override
    protected void decode(final ChannelHandlerContext ctx, HttpObject msg, List<Object> out) throws Exception {
        // HttpMessage is tested first because a FullHttpMessage is also an HttpContent.
        if (msg instanceof HttpMessage) {
            startMessage(ctx, (HttpMessage) msg, out);
        } else if (msg instanceof HttpContent) {
            appendContent((HttpContent) msg, out);
        } else {
            throw new IllegalStateException("Unexpected message type: " + msg.getClass().getName());
        }
    }

    /**
     * Handles the header of a new message.
     * <p>
     * Failed or already-complete messages are forwarded immediately. Otherwise a new aggregate is
     * created and the method waits for content chunks.
     */
    private void startMessage(ChannelHandlerContext ctx, HttpMessage message, List<Object> out) {
        tooLongFrameFound = false;
        // Normally nothing is in progress here, because the previous aggregate was emitted on its
        // LastHttpContent. If one is still pending, release its composite buffer instead of
        // silently overwriting (and leaking) it.
        releaseCurrentMessage();

        // Handle the 'Expect: 100-continue' header if necessary.
        // TODO: when the declared length already exceeds maxContentLength, respond with
        // 417 Expectation Failed / 413 Request Entity Too Large and discard the body instead.
        if (is100ContinueExpected(message)) {
            sendContinue(ctx);
        }

        if (!message.getDecoderResult().isSuccess()) {
            // Broken header: forward it at once as a content-less full message.
            removeTransferEncodingChunked(message);
            out.add(toFullMessage(message));
            return;
        }
        if (message instanceof FullHttpMessage) {
            // Already aggregated, and a FullHttpMessage is its own LastHttpContent. Waiting for
            // more chunks would drop its body and never emit it. It is retained because the
            // MessageToMessageDecoder superclass releases every inbound message after decode().
            out.add(((FullHttpMessage) message).retain());
            return;
        }

        AggregatedFullHttpMessage aggregated;
        if (message instanceof HttpRequest) {
            aggregated = new AggregatedFullHttpRequest((HttpRequest) message,
                    ctx.alloc().compositeBuffer(maxCumulationBufferComponents), null);
        } else if (message instanceof HttpResponse) {
            aggregated = new AggregatedFullHttpResponse((HttpResponse) message,
                    ctx.alloc().compositeBuffer(maxCumulationBufferComponents), null);
        } else {
            throw new IllegalStateException("Unexpected message type: " + message.getClass().getName());
        }
        // Track the aggregate first so that channelInactive/handlerRemoved can release it.
        currentMessage = aggregated;
        // A streamed message: chunked framing is removed here, and an explicit Content-Length is
        // set once the last chunk arrives.
        removeTransferEncodingChunked(aggregated);
    }

    /**
     * Appends a content chunk to the aggregate in progress.
     * <p>
     * When the chunk is the last one, or failed to decode, the completed message is added to
     * {@code out}.
     *
     * @throws TooLongFrameException if appending would exceed the configured maximum
     */
    private void appendContent(HttpContent chunk, List<Object> out) {
        AggregatedFullHttpMessage current = this.currentMessage;
        if (tooLongFrameFound || current == null) {
            // Either the current message was already rejected as too long, or no message is in
            // progress (for example, its header failed to decode and was forwarded as-is).
            // Discard the chunk; the superclass releases it after decode() returns.
            return;
        }

        CompositeByteBuf content = (CompositeByteBuf) current.content();
        ByteBuf chunkContent = chunk.content();
        // Written as a subtraction so the comparison cannot overflow int.
        if (content.readableBytes() > maxContentLength - chunkContent.readableBytes()) {
            tooLongFrameFound = true;
            // Release the partial aggregate to prevent leaks.
            releaseCurrentMessage();
            throw new TooLongFrameException(
                    "HTTP content length exceeded " + maxContentLength +
                    " bytes.");
        }

        if (chunkContent.isReadable()) {
            // The composite takes ownership of one reference; the superclass releases the other.
            chunk.retain();
            content.addComponent(chunkContent);
            content.writerIndex(content.writerIndex() + chunkContent.readableBytes());
        }

        final boolean last;
        if (!chunk.getDecoderResult().isSuccess()) {
            // A broken chunk ends the message; the failure is propagated to the aggregate.
            current.setDecoderResult(DecoderResult.failure(chunk.getDecoderResult().cause()));
            last = true;
        } else {
            last = chunk instanceof LastHttpContent;
        }
        if (!last) {
            return;
        }

        this.currentMessage = null;
        // Merge trailing headers into the message.
        if (chunk instanceof LastHttpContent) {
            current.setTrailingHeaders(((LastHttpContent) chunk).trailingHeaders());
        } else {
            current.setTrailingHeaders(new DefaultHttpHeaders());
        }
        current.headers().set(
                HttpHeaders.Names.CONTENT_LENGTH,
                String.valueOf(content.readableBytes()));
        out.add(current);
    }

    /**
     * Writes the {@code 100 Continue} line directly as bytes, bypassing the response encoder.
     * <p>
     * A write failure is reported through {@code exceptionCaught}.
     */
    private static void sendContinue(final ChannelHandlerContext ctx) {
        ctx.writeAndFlush(CONTINUE_LINE.duplicate()).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (!future.isSuccess()) {
                    ctx.fireExceptionCaught(future.cause());
                }
            }
        });
    }

    /** Releases and clears the aggregate in progress, if there is one. */
    private void releaseCurrentMessage() {
        if (currentMessage != null) {
            currentMessage.release();
            currentMessage = null;
        }
    }

    /** Releases any partially aggregated message when the channel closes. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        super.channelInactive(ctx);
        // The current message, if any, is a left-over that will never be completed.
        releaseCurrentMessage();
    }

    /** Remembers the context; from now on {@link #setMaxCumulationBufferComponents} is rejected. */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        this.ctx = ctx;
    }

    /** Releases any partially aggregated message when this handler leaves the pipeline. */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        super.handlerRemoved(ctx);
        // The current message, if any, is a left-over; there is not much more we can do with it.
        releaseCurrentMessage();
    }

    /**
     * Converts a header-only message into a {@link FullHttpMessage}.
     * <p>
     * A message that is already full is returned retained. Any other message is wrapped with
     * empty content and empty trailing headers.
     *
     * @throws IllegalStateException if {@code msg} is neither a request nor a response
     */
    private static FullHttpMessage toFullMessage(HttpMessage msg) {
        if (msg instanceof FullHttpMessage) {
            return ((FullHttpMessage) msg).retain();
        }
        FullHttpMessage fullMsg;
        if (msg instanceof HttpRequest) {
            fullMsg = new AggregatedFullHttpRequest(
                    (HttpRequest) msg, Unpooled.EMPTY_BUFFER, new DefaultHttpHeaders());
        } else if (msg instanceof HttpResponse) {
            fullMsg = new AggregatedFullHttpResponse(
                    (HttpResponse) msg, Unpooled.EMPTY_BUFFER, new DefaultHttpHeaders());
        } else {
            throw new IllegalStateException();
        }
        return fullMsg;
    }

    /**
     * Presents a header {@link HttpMessage} plus a separate content buffer as a
     * {@link FullHttpMessage}.
     * <p>
     * Version, header and decoder-result accessors delegate to the wrapped message, so changes
     * made through this object are visible on the original header object. Trailing headers may
     * be {@code null} until aggregation completes.
     */
    private abstract static class AggregatedFullHttpMessage extends DefaultByteBufHolder implements FullHttpMessage {
        protected final HttpMessage message;
        private HttpHeaders trailingHeaders;

        private AggregatedFullHttpMessage(HttpMessage message, ByteBuf content, HttpHeaders trailingHeaders) {
            super(content);
            this.message = message;
            this.trailingHeaders = trailingHeaders;
        }

        @Override
        public HttpHeaders trailingHeaders() {
            return trailingHeaders;
        }

        /** Replaces the trailing headers; called when the last chunk is merged. */
        public void setTrailingHeaders(HttpHeaders trailingHeaders) {
            this.trailingHeaders = trailingHeaders;
        }

        @Override
        public HttpVersion getProtocolVersion() {
            return message.getProtocolVersion();
        }

        @Override
        public FullHttpMessage setProtocolVersion(HttpVersion version) {
            message.setProtocolVersion(version);
            return this;
        }

        @Override
        public HttpHeaders headers() {
            return message.headers();
        }

        @Override
        public DecoderResult getDecoderResult() {
            return message.getDecoderResult();
        }

        @Override
        public void setDecoderResult(DecoderResult result) {
            message.setDecoderResult(result);
        }

        @Override
        public FullHttpMessage retain(int increment) {
            super.retain(increment);
            return this;
        }

        @Override
        public FullHttpMessage retain() {
            super.retain();
            return this;
        }

        @Override
        public abstract FullHttpMessage copy();

        @Override
        public abstract FullHttpMessage duplicate();
    }

    /** {@link FullHttpRequest} view over an aggregated {@link HttpRequest} header. */
    private static final class AggregatedFullHttpRequest extends AggregatedFullHttpMessage implements FullHttpRequest {
        private AggregatedFullHttpRequest(HttpRequest request, ByteBuf content, HttpHeaders trailingHeaders) {
            super(request, content, trailingHeaders);
        }

        /** Returns a new request whose content is a copy and whose headers are copied. */
        @Override
        public FullHttpRequest copy() {
            DefaultFullHttpRequest copy = new DefaultFullHttpRequest(
                    getProtocolVersion(), getMethod(), getUri(), content().copy());
            copy.headers().set(headers());
            copy.trailingHeaders().set(trailingHeaders());
            return copy;
        }

        /** Returns a new request that shares this content (via duplicate) and copies the headers. */
        @Override
        public FullHttpRequest duplicate() {
            DefaultFullHttpRequest duplicate = new DefaultFullHttpRequest(
                    getProtocolVersion(), getMethod(), getUri(), content().duplicate());
            duplicate.headers().set(headers());
            duplicate.trailingHeaders().set(trailingHeaders());
            return duplicate;
        }

        @Override
        public FullHttpRequest retain(int increment) {
            super.retain(increment);
            return this;
        }

        @Override
        public FullHttpRequest retain() {
            super.retain();
            return this;
        }

        @Override
        public FullHttpRequest setMethod(HttpMethod method) {
            ((HttpRequest) message).setMethod(method);
            return this;
        }

        @Override
        public FullHttpRequest setUri(String uri) {
            ((HttpRequest) message).setUri(uri);
            return this;
        }

        @Override
        public HttpMethod getMethod() {
            return ((HttpRequest) message).getMethod();
        }

        @Override
        public String getUri() {
            return ((HttpRequest) message).getUri();
        }

        @Override
        public FullHttpRequest setProtocolVersion(HttpVersion version) {
            super.setProtocolVersion(version);
            return this;
        }
    }

    /** {@link FullHttpResponse} view over an aggregated {@link HttpResponse} header. */
    private static final class AggregatedFullHttpResponse extends AggregatedFullHttpMessage
            implements FullHttpResponse {
        private AggregatedFullHttpResponse(HttpResponse message, ByteBuf content, HttpHeaders trailingHeaders) {
            super(message, content, trailingHeaders);
        }

        /** Returns a new response whose content is a copy and whose headers are copied. */
        @Override
        public FullHttpResponse copy() {
            DefaultFullHttpResponse copy = new DefaultFullHttpResponse(
                    getProtocolVersion(), getStatus(), content().copy());
            copy.headers().set(headers());
            copy.trailingHeaders().set(trailingHeaders());
            return copy;
        }

        /** Returns a new response that shares this content (via duplicate) and copies the headers. */
        @Override
        public FullHttpResponse duplicate() {
            DefaultFullHttpResponse duplicate = new DefaultFullHttpResponse(getProtocolVersion(), getStatus(),
                    content().duplicate());
            duplicate.headers().set(headers());
            duplicate.trailingHeaders().set(trailingHeaders());
            return duplicate;
        }

        @Override
        public FullHttpResponse setStatus(HttpResponseStatus status) {
            ((HttpResponse) message).setStatus(status);
            return this;
        }

        @Override
        public HttpResponseStatus getStatus() {
            return ((HttpResponse) message).getStatus();
        }

        @Override
        public FullHttpResponse setProtocolVersion(HttpVersion version) {
            super.setProtocolVersion(version);
            return this;
        }

        @Override
        public FullHttpResponse retain(int increment) {
            super.retain(increment);
            return this;
        }

        @Override
        public FullHttpResponse retain() {
            super.retain();
            return this;
        }
    }
}