/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.netty4.pipelining.HttpPipelinedRequest;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import static org.elasticsearch.test.hamcrest.RegexMatcher.matches;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.hasSize;
/**
* This test just tests, if he pipelining works in general with out any connection the Elasticsearch handler
*/
public class Netty4HttpServerPipeliningTests extends ESTestCase {
private NetworkService networkService;
private ThreadPool threadPool;
private MockBigArrays bigArrays;
@Before
public void setup() throws Exception {
networkService = new NetworkService(Settings.EMPTY, Collections.emptyList());
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(Settings.EMPTY, new NoneCircuitBreakerService());
}
@After
public void shutdown() throws Exception {
if (threadPool != null) {
threadPool.shutdownNow();
}
}
public void testThatHttpPipeliningWorksWhenEnabled() throws Exception {
final Settings settings = Settings.builder()
.put("http.pipelining", true)
.put("http.port", "0")
.build();
try (HttpServerTransport httpServerTransport = new CustomNettyHttpServerTransport(settings)) {
httpServerTransport.start();
final TransportAddress transportAddress = randomFrom(httpServerTransport.boundAddress().boundAddresses());
final int numberOfRequests = randomIntBetween(4, 16);
final List<String> requests = new ArrayList<>(numberOfRequests);
for (int i = 0; i < numberOfRequests; i++) {
if (rarely()) {
requests.add("/slow/" + i);
} else {
requests.add("/" + i);
}
}
try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) {
Collection<FullHttpResponse> responses = nettyHttpClient.get(transportAddress.address(), requests.toArray(new String[]{}));
Collection<String> responseBodies = Netty4HttpClient.returnHttpResponseBodies(responses);
assertThat(responseBodies, contains(requests.toArray()));
}
}
}
public void testThatHttpPipeliningCanBeDisabled() throws Exception {
final Settings settings = Settings.builder()
.put("http.pipelining", false)
.put("http.port", "0")
.build();
try (HttpServerTransport httpServerTransport = new CustomNettyHttpServerTransport(settings)) {
httpServerTransport.start();
final TransportAddress transportAddress = randomFrom(httpServerTransport.boundAddress().boundAddresses());
final int numberOfRequests = randomIntBetween(4, 16);
final Set<Integer> slowIds = new HashSet<>();
final List<String> requests = new ArrayList<>(numberOfRequests);
for (int i = 0; i < numberOfRequests; i++) {
if (rarely()) {
requests.add("/slow/" + i);
slowIds.add(i);
} else {
requests.add("/" + i);
}
}
try (Netty4HttpClient nettyHttpClient = new Netty4HttpClient()) {
Collection<FullHttpResponse> responses = nettyHttpClient.get(transportAddress.address(), requests.toArray(new String[]{}));
List<String> responseBodies = new ArrayList<>(Netty4HttpClient.returnHttpResponseBodies(responses));
// we can not be sure about the order of the responses, but the slow ones should come last
assertThat(responseBodies, hasSize(numberOfRequests));
for (int i = 0; i < numberOfRequests - slowIds.size(); i++) {
assertThat(responseBodies.get(i), matches("/\\d+"));
}
final Set<Integer> ids = new HashSet<>();
for (int i = 0; i < slowIds.size(); i++) {
final String response = responseBodies.get(numberOfRequests - slowIds.size() + i);
assertThat(response, matches("/slow/\\d+" ));
assertTrue(ids.add(Integer.parseInt(response.split("/")[2])));
}
assertThat(slowIds, equalTo(ids));
}
}
}
class CustomNettyHttpServerTransport extends Netty4HttpServerTransport {
private final ExecutorService executorService = Executors.newCachedThreadPool();
CustomNettyHttpServerTransport(final Settings settings) {
super(settings,
Netty4HttpServerPipeliningTests.this.networkService,
Netty4HttpServerPipeliningTests.this.bigArrays,
Netty4HttpServerPipeliningTests.this.threadPool,
xContentRegistry(), new NullDispatcher());
}
@Override
public ChannelHandler configureServerChannelHandler() {
return new CustomHttpChannelHandler(this, executorService, Netty4HttpServerPipeliningTests.this.threadPool.getThreadContext());
}
@Override
protected void doClose() {
executorService.shutdown();
super.doClose();
}
}
private class CustomHttpChannelHandler extends Netty4HttpServerTransport.HttpChannelHandler {
private final ExecutorService executorService;
CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService, ThreadContext threadContext) {
super(transport, randomBoolean(), threadContext);
this.executorService = executorService;
}
@Override
protected void initChannel(Channel ch) throws Exception {
super.initChannel(ch);
ch.pipeline().replace("handler", "handler", new PossiblySlowUpstreamHandler(executorService));
}
}
class PossiblySlowUpstreamHandler extends SimpleChannelInboundHandler<Object> {
private final ExecutorService executorService;
PossiblySlowUpstreamHandler(ExecutorService executorService) {
this.executorService = executorService;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
executorService.submit(new PossiblySlowRunnable(ctx, msg));
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.info("Caught exception", cause);
ctx.channel().close().sync();
}
}
class PossiblySlowRunnable implements Runnable {
private ChannelHandlerContext ctx;
private HttpPipelinedRequest pipelinedRequest;
private FullHttpRequest fullHttpRequest;
PossiblySlowRunnable(ChannelHandlerContext ctx, Object msg) {
this.ctx = ctx;
if (msg instanceof HttpPipelinedRequest) {
this.pipelinedRequest = (HttpPipelinedRequest) msg;
} else if (msg instanceof FullHttpRequest) {
this.fullHttpRequest = (FullHttpRequest) msg;
}
}
@Override
public void run() {
final String uri;
if (pipelinedRequest != null && pipelinedRequest.last() instanceof FullHttpRequest) {
uri = ((FullHttpRequest) pipelinedRequest.last()).uri();
} else {
uri = fullHttpRequest.uri();
}
final ByteBuf buffer = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8);
final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer);
httpResponse.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes());
final boolean slow = uri.matches("/slow/\\d+");
if (slow) {
try {
Thread.sleep(scaledRandomIntBetween(500, 1000));
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
} else {
assert uri.matches("/\\d+");
}
final ChannelPromise promise = ctx.newPromise();
final Object msg;
if (pipelinedRequest != null) {
msg = pipelinedRequest.createHttpResponse(httpResponse, promise);
} else {
msg = httpResponse;
}
ctx.writeAndFlush(msg, promise);
}
}
}