/*
* Copyright 2013-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.glowroot.ui;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Charsets;
import com.google.common.base.Strings;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.net.MediaType;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.stream.ChunkedInput;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.glowroot.ui.CommonHandler.CommonRequest;
import org.glowroot.ui.CommonHandler.CommonResponse;
import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST;
import static io.netty.handler.codec.http.HttpResponseStatus.FOUND;
import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static java.util.concurrent.TimeUnit.SECONDS;
/**
 * Netty inbound handler that adapts incoming {@link FullHttpRequest}s to the
 * {@link CommonHandler} abstraction and writes the resulting {@link CommonResponse} back to the
 * channel.
 *
 * <p>A single instance is shared across all channels ({@link Sharable}). Every channel that
 * becomes active is recorded in {@link #allChannels} so that channels can be closed in bulk via
 * {@link #close(boolean)} and {@link #closeAllButCurrent()}.
 */
@Sharable
class HttpServerHandler extends ChannelInboundHandlerAdapter {

    private static final Logger logger = LoggerFactory.getLogger(HttpServerHandler.class);

    // every channel that has become active on this handler (added in channelActive)
    private final ChannelGroup allChannels;

    // supplies the current context path; read once per request
    private final Supplier<String> contextPathSupplier;

    private final CommonHandler commonHandler;

    // channel whose request is being handled by the current thread, so that closeAllButCurrent()
    // can skip it; only set for the duration of handleRequest()
    private final ThreadLocal</*@Nullable*/ Channel> currentChannel =
            new ThreadLocal</*@Nullable*/ Channel>();

    HttpServerHandler(Supplier<String> contextPathSupplier, CommonHandler commonHandler) {
        this.contextPathSupplier = contextPathSupplier;
        this.commonHandler = commonHandler;
        allChannels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        allChannels.add(ctx.channel());
        super.channelActive(ctx);
    }

    /**
     * Closes every tracked channel.
     *
     * @param waitForChannelClose if {@code true}, block until all channels have closed; otherwise
     *        wait at most one second
     */
    void close(boolean waitForChannelClose) {
        if (waitForChannelClose) {
            allChannels.close().awaitUninterruptibly();
        } else {
            allChannels.close().awaitUninterruptibly(1, SECONDS);
        }
    }

    /**
     * Closes every tracked channel except the one whose request the calling thread is currently
     * handling (if any). Blocks until each close completes.
     */
    void closeAllButCurrent() {
        Channel current = currentChannel.get();
        for (Channel channel : allChannels) {
            if (channel != current) {
                channel.close().awaitUninterruptibly();
            }
        }
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) {
        // responses are only written (not flushed) while handling reads; flush them here
        ctx.flush();
    }

    /**
     * Handles one HTTP request and writes (without flushing) its response.
     *
     * <p>NOTE(review): {@code msg} is cast to {@link FullHttpRequest} unconditionally, which
     * presumes an aggregating handler earlier in the pipeline.
     *
     * <p>The request is released on every path, including the decoder-failure path.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        FullHttpRequest request = (FullHttpRequest) msg;
        try {
            if (request.decoderResult().isFailure()) {
                CommonResponse response = new CommonResponse(BAD_REQUEST,
                        MediaType.PLAIN_TEXT_UTF_8,
                        Strings.nullToEmpty(request.decoderResult().cause().getMessage()));
                sendResponse(ctx, request, response, false);
            } else {
                handleRequest(ctx, request);
            }
        } finally {
            request.release();
        }
    }

    // Dispatches a successfully decoded request to commonHandler. Requests outside the context
    // path are redirected to it. Any exception is converted into a 500 response with the
    // connection closed afterwards.
    private void handleRequest(ChannelHandlerContext ctx, FullHttpRequest request)
            throws Exception {
        String uri = request.uri();
        logger.debug("channelRead(): request.uri={}", uri);
        currentChannel.set(ctx.channel());
        try {
            String contextPath = contextPathSupplier.get();
            boolean keepAlive = HttpUtil.isKeepAlive(request);
            if (!uri.startsWith(contextPath)) {
                // outside the context path: redirect to the context path itself
                DefaultFullHttpResponse response = new DefaultFullHttpResponse(HTTP_1_1, FOUND);
                response.headers().set(HttpHeaderNames.LOCATION, contextPath);
                sendFullResponse(ctx, request, response, keepAlive);
                return;
            }
            QueryStringDecoder decoder =
                    new QueryStringDecoder(stripContextPath(uri, contextPath));
            CommonRequest commonRequest = new NettyRequest(request, contextPath, decoder);
            CommonResponse response = commonHandler.handle(commonRequest);
            if (response.isCloseConnectionAfterPortChange()) {
                response.setHeader("Connection", "close");
                keepAlive = false;
            }
            sendResponse(ctx, request, response, keepAlive);
        } catch (Exception e) {
            logger.error("error handling request {}: {}", uri, e.getMessage(), e);
            CommonResponse response =
                    CommonHandler.newHttpResponseWithStackTrace(e, INTERNAL_SERVER_ERROR, null);
            sendResponse(ctx, request, response, false);
        } finally {
            currentChannel.remove();
        }
    }

    /**
     * Writes {@code response} to the channel (without flushing).
     *
     * <p>{@link ByteBuf} content is sent as a single full response. {@link ChunkSource} content is
     * streamed with chunked transfer encoding; it is sent as a zip file download when
     * {@link CommonResponse#getZipFileName()} is non-null.
     *
     * @param keepAlive whether the connection should stay open after the response
     * @throws IOException if building the chunked input fails
     * @throws IllegalStateException if the content is neither a {@code ByteBuf} nor a
     *         {@code ChunkSource}
     */
    private void sendResponse(ChannelHandlerContext ctx, FullHttpRequest request,
            CommonResponse response, boolean keepAlive) throws IOException {
        Object content = response.getContent();
        if (content instanceof ByteBuf) {
            FullHttpResponse resp = new DefaultFullHttpResponse(HTTP_1_1, response.getStatus(),
                    (ByteBuf) content, response.getHeaders(), EmptyHttpHeaders.INSTANCE);
            sendFullResponse(ctx, request, resp, keepAlive);
        } else if (content instanceof ChunkSource) {
            // chunked responses are always sent with 200 OK; response.getStatus() is not used here
            HttpResponse resp = new DefaultHttpResponse(HTTP_1_1, OK, response.getHeaders());
            resp.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
            if (!keepAlive) {
                // tell the client the connection will be closed once the body is written
                resp.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
            }
            ctx.write(resp);
            ChunkSource chunkSource = (ChunkSource) content;
            ChunkedInput<HttpContent> chunkedInput;
            String zipFileName = response.getZipFileName();
            if (zipFileName == null) {
                chunkedInput = ChunkedInputs.create(chunkSource);
            } else {
                chunkedInput = ChunkedInputs.createZipFileDownload(chunkSource, zipFileName);
            }
            ChannelFuture future = ctx.write(chunkedInput);
            HttpServices.addErrorListener(future);
            if (!keepAlive) {
                HttpServices.addCloseListener(future);
            }
        } else {
            String contentType = content == null ? "null" : content.getClass().getName();
            throw new IllegalStateException("Unexpected content: " + contentType);
        }
    }

    /**
     * Writes a complete response (without flushing). Sets Content-Length and the Connection
     * header, and closes the channel after the write when {@code keepAlive} is false.
     */
    @SuppressWarnings("argument.type.incompatible")
    private void sendFullResponse(ChannelHandlerContext ctx, FullHttpRequest request,
            FullHttpResponse response, boolean keepAlive) {
        // set (not add) so a Content-Length already present in the headers is not duplicated
        response.headers().setInt(HttpHeaderNames.CONTENT_LENGTH,
                response.content().readableBytes());
        if (keepAlive) {
            // keep-alive must be explicit only for protocol versions where it is not the default
            if (!request.protocolVersion().isKeepAliveDefault()) {
                response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
            }
        } else {
            // tell the client the connection will be closed after this response
            response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
        }
        ChannelFuture f = ctx.write(response);
        if (!keepAlive) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        if (HttpServices.shouldLogException(cause)) {
            logger.warn(cause.getMessage(), cause);
        }
        ctx.close();
    }

    /**
     * Removes {@code contextPath} from the front of {@code path}.
     *
     * <p>Returns {@code path} unchanged when {@code contextPath} is {@code "/"}, and {@code "/"}
     * when {@code path} equals {@code contextPath} exactly. Otherwise the first
     * {@code contextPath.length()} characters are dropped. The caller is expected to have
     * checked that {@code path} starts with {@code contextPath}.
     */
    @VisibleForTesting
    static String stripContextPath(String path, String contextPath) {
        if (contextPath.equals("/")) {
            return path;
        }
        if (path.equals(contextPath)) {
            return "/";
        }
        return path.substring(contextPath.length());
    }

    /**
     * {@link CommonRequest} view of a Netty {@link FullHttpRequest}. Path and query parameters
     * come from a {@link QueryStringDecoder} built on the context-path-stripped URI.
     */
    private static class NettyRequest implements CommonRequest {

        private final FullHttpRequest request;
        private final String contextPath;
        private final QueryStringDecoder decoder;

        NettyRequest(FullHttpRequest request, String contextPath, QueryStringDecoder decoder) {
            this.request = request;
            this.contextPath = contextPath;
            this.decoder = decoder;
        }

        @Override
        public String getMethod() {
            return request.method().name();
        }

        // includes context path
        @Override
        public String getUri() {
            return request.uri();
        }

        @Override
        public String getContextPath() {
            return contextPath;
        }

        // does not include context path
        @Override
        public String getPath() {
            return decoder.path();
        }

        @Override
        public @Nullable String getHeader(CharSequence name) {
            return request.headers().getAsString(name);
        }

        @Override
        public Map<String, List<String>> getParameters() {
            return decoder.parameters();
        }

        // returns an empty immutable list when the parameter is absent
        @Override
        public List<String> getParameters(String name) {
            List<String> params = decoder.parameters().get(name);
            if (params == null) {
                return ImmutableList.of();
            } else {
                return params;
            }
        }

        // request body decoded as UTF-8
        @Override
        public String getContent() {
            return request.content().toString(Charsets.UTF_8);
        }
    }
}