/**
* Copyright 2016 LinkedIn Corp. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
package com.github.ambry.tools.perf.rest;
import com.codahale.metrics.Counter;
import com.codahale.metrics.Histogram;
import com.codahale.metrics.JmxReporter;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Snapshot;
import com.github.ambry.commons.SSLFactory;
import com.github.ambry.config.SSLConfig;
import com.github.ambry.config.VerifiableProperties;
import com.github.ambry.rest.RestUtils;
import com.github.ambry.utils.Time;
import com.github.ambry.utils.Utils;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpChunkedInput;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedInput;
import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import joptsimple.ArgumentAcceptingOptionSpec;
import joptsimple.OptionParser;
import joptsimple.OptionSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A Netty based client to evaluate performance of the front end.
*/
public class NettyPerfClient {
private static final String GET = "GET";
private static final String POST = "POST";
private static final List<String> SUPPORTED_REQUEST_TYPES = Arrays.asList(GET, POST);
private static final Logger logger = LoggerFactory.getLogger(NettyPerfClient.class);
private final String host;
private final int port;
private final String uri;
private final int concurrency;
private final long totalSize;
private final byte[] chunk;
private final SSLFactory sslFactory;
private final Bootstrap b = new Bootstrap();
private final ChannelConnectListener channelConnectListener = new ChannelConnectListener();
private final MetricRegistry metricRegistry = new MetricRegistry();
private final JmxReporter reporter = JmxReporter.forRegistry(metricRegistry).build();
private final PerfClientMetrics perfClientMetrics = new PerfClientMetrics(metricRegistry);
private final CountDownLatch shutdownLatch = new CountDownLatch(1);
private final AtomicLong totalRequestCount = new AtomicLong(0);
private EventLoopGroup group;
private long perfClientStartTime;
private volatile boolean isRunning = false;
/**
* Abstraction class for all the parameters that are expected.
*/
private static class ClientArgs {
final String host;
final Integer port;
final String path;
final String requestType;
final Integer concurrency;
final Long postBlobTotalSize;
final Integer postBlobChunkSize;
final String sslPropsFilePath;
private final Logger logger = LoggerFactory.getLogger(getClass());
/**
* Parses the arguments provided and extracts them into variables that can be retrieved.
* @param args the command line argument list.
*/
protected ClientArgs(String args[]) {
OptionParser parser = new OptionParser();
ArgumentAcceptingOptionSpec<String> host = parser.accepts("host", "Front end host to contact")
.withOptionalArg()
.describedAs("host")
.ofType(String.class)
.defaultsTo("localhost");
ArgumentAcceptingOptionSpec<Integer> port = parser.accepts("port", "Front end port")
.withOptionalArg()
.describedAs("port")
.ofType(Integer.class)
.defaultsTo(1174);
ArgumentAcceptingOptionSpec<String> path = parser.accepts("path", "Resource path (prefix with a '/')")
.withOptionalArg()
.describedAs("path")
.ofType(String.class)
.defaultsTo("/");
ArgumentAcceptingOptionSpec<String> requestType =
parser.accepts("requestType", "The type of request to make (POST, GET)")
.withOptionalArg()
.describedAs("requestType")
.ofType(String.class)
.defaultsTo(GET);
ArgumentAcceptingOptionSpec<Integer> concurrency = parser.accepts("concurrency", "Number of parallel requests")
.withOptionalArg()
.describedAs("concurrency")
.ofType(Integer.class)
.defaultsTo(1);
ArgumentAcceptingOptionSpec<Long> postBlobTotalSize =
parser.accepts("postBlobTotalSize", "Total size in bytes of blob to be POSTed")
.withOptionalArg()
.describedAs("postBlobTotalSize")
.ofType(Long.class);
ArgumentAcceptingOptionSpec<Integer> postBlobChunkSize =
parser.accepts("postBlobChunkSize", "Size in bytes of each chunk that will be POSTed")
.withOptionalArg()
.describedAs("postBlobChunkSize")
.ofType(Integer.class);
ArgumentAcceptingOptionSpec<String> sslPropsFilePath =
parser.accepts("sslPropsFilePath", "The path to the properties file with SSL settings")
.withOptionalArg()
.describedAs("sslPropsFilePath")
.ofType(String.class);
OptionSet options = parser.parse(args);
this.host = options.valueOf(host);
this.port = options.valueOf(port);
this.path = options.valueOf(path);
this.requestType = options.valueOf(requestType);
this.concurrency = options.valueOf(concurrency);
this.postBlobTotalSize = options.valueOf(postBlobTotalSize);
this.postBlobChunkSize = options.valueOf(postBlobChunkSize);
this.sslPropsFilePath = options.valueOf(sslPropsFilePath);
validateArgs();
logger.info("Host: {}", this.host);
logger.info("Port: {}", this.port);
logger.info("Path: {}", this.path);
logger.info("Request type: {}", this.requestType);
logger.info("Concurrency: {}", this.concurrency);
logger.info("Post blob total size: {}", this.postBlobTotalSize);
logger.info("Post blob chunk size: {}", this.postBlobChunkSize);
logger.info("SSL properties file path: {}", this.sslPropsFilePath);
}
/**
* Validates the arguments given and verifies relationships b/w them if any exist.
*/
private void validateArgs() {
if (!SUPPORTED_REQUEST_TYPES.contains(requestType)) {
throw new IllegalArgumentException("Unsupported request type: " + requestType);
} else if (requestType.equals(POST) && (postBlobTotalSize == null || postBlobTotalSize <= 0
|| postBlobChunkSize == null || postBlobChunkSize <= 0)) {
throw new IllegalArgumentException(
"Total size to be posted and size of each chunk need to be specified with POST and have to be > 0");
}
}
}
/**
* Invokes the {@link NettyPerfClient} with the command line arguments.
* @param args command line arguments.
*/
public static void main(String[] args) {
try {
ClientArgs clientArgs = new ClientArgs(args);
final NettyPerfClient nettyPerfClient =
new NettyPerfClient(clientArgs.host, clientArgs.port, clientArgs.path, clientArgs.concurrency,
clientArgs.postBlobTotalSize, clientArgs.postBlobChunkSize, clientArgs.sslPropsFilePath);
// attach shutdown handler to catch control-c
Runtime.getRuntime().addShutdownHook(new Thread() {
public void run() {
logger.info("Received shutdown signal. Requesting NettyPerfClient shutdown");
nettyPerfClient.shutdown();
}
});
nettyPerfClient.start();
nettyPerfClient.awaitShutdown();
} catch (Exception e) {
logger.error("Exception during execution of NettyPerfClient", e);
}
}
/**
* Creates an instance of NettyPerfClient
* @param host host to contact.
* @param port port to contact.
* @param path resource path.
* @param concurrency number of parallel requests.
* @param totalSize the total size in bytes of a blob to be POSTed ({@code null} if non-POST).
* @param chunkSize size in bytes of each chunk to be POSTed ({@code null} if non-POST).
* @param sslPropsFilePath the path to the SSL properties, or {@code null} to disable SSL.
* @throws IOException
* @throws GeneralSecurityException
*/
private NettyPerfClient(String host, int port, String path, int concurrency, Long totalSize, Integer chunkSize,
String sslPropsFilePath) throws IOException, GeneralSecurityException {
this.host = host;
this.port = port;
this.uri = "http://" + host + ":" + port + path;
this.concurrency = concurrency;
if (chunkSize != null) {
this.totalSize = totalSize;
chunk = new byte[chunkSize];
new Random().nextBytes(chunk);
} else {
this.totalSize = 0;
chunk = null;
}
sslFactory = sslPropsFilePath != null ? new SSLFactory(
new SSLConfig(new VerifiableProperties(Utils.loadProps(sslPropsFilePath)))) : null;
logger.info("Instantiated NettyPerfClient which will interact with host {}, port {}, uri {} with concurrency {}",
this.host, this.port, uri, this.concurrency);
}
/**
* Starts the NettyPerfClient.
* @throws InterruptedException
*/
protected void start() throws InterruptedException {
logger.info("Starting NettyPerfClient");
reporter.start();
group = new NioEventLoopGroup(concurrency);
b.group(group).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
if (sslFactory != null) {
ch.pipeline()
.addLast("sslHandler", new SslHandler(sslFactory.createSSLEngine(host, port, SSLFactory.Mode.CLIENT)));
}
ch.pipeline().addLast(new HttpClientCodec()).addLast(new ChunkedWriteHandler()).addLast(new ResponseHandler());
}
});
logger.info("Connecting to {}:{}", host, port);
b.remoteAddress(host, port);
perfClientStartTime = System.currentTimeMillis();
for (int i = 0; i < concurrency; i++) {
b.connect().addListener(channelConnectListener);
}
isRunning = true;
logger.info("Created {} channel(s)", concurrency);
logger.info("NettyPerfClient started");
}
/**
* Shuts down the NettyPerfClient.
*/
protected void shutdown() {
logger.info("Shutting down NettyPerfClient");
isRunning = false;
group.shutdownGracefully();
long totalRunTimeInMs = System.currentTimeMillis() - perfClientStartTime;
try {
if (!group.awaitTermination(5, TimeUnit.SECONDS)) {
logger.error("Netty worker did not shutdown within timeout");
} else {
logger.info("NettyPerfClient shutdown complete");
}
} catch (InterruptedException e) {
logger.error("NettyPerfClient shutdown interrupted", e);
} finally {
logger.info("Executed for approximately {} s and sent {} requests ({} requests/sec)",
(float) totalRunTimeInMs / (float) Time.MsPerSec, totalRequestCount.get(),
(float) totalRequestCount.get() * (float) Time.MsPerSec / (float) totalRunTimeInMs);
Snapshot rttStatsSnapshot = perfClientMetrics.requestRoundTripTimeInMs.getSnapshot();
logger.info("RTT stats: Min - {} ms, Mean - {} ms, Max - {} ms", rttStatsSnapshot.getMin(),
rttStatsSnapshot.getMean(), rttStatsSnapshot.getMax());
logger.info("RTT stats: 95th percentile - {} ms, 99th percentile - {} ms, 999th percentile - {} ms",
rttStatsSnapshot.get95thPercentile(), rttStatsSnapshot.get99thPercentile(),
rttStatsSnapshot.get999thPercentile());
reporter.stop();
shutdownLatch.countDown();
}
}
/**
* Blocking function to wait on the NettyPerfClient shutting down.
* @throws InterruptedException
*/
protected void awaitShutdown() throws InterruptedException {
shutdownLatch.await();
}
/**
* Custom handler that sends out the request and receives and processes the response.
*/
private class ResponseHandler extends SimpleChannelInboundHandler<HttpObject> {
private final Logger logger = LoggerFactory.getLogger(getClass());
private HttpRequest request;
private HttpResponse response;
private ChunkedInput<HttpContent> chunkedInput;
private int chunksReceived;
private long sizeReceived;
private long lastChunkReceiveTime;
private long requestStartTime;
private long requestId = 0;
@Override
public void channelActive(ChannelHandlerContext ctx) {
perfClientMetrics.channelCreationRate.mark();
logger.trace("Channel {} active", ctx.channel());
sendRequest(ctx);
}
@Override
public void channelRead0(ChannelHandlerContext ctx, HttpObject in) {
long currentChunkReceiveTime = System.currentTimeMillis();
boolean recognized = false;
if (in instanceof HttpResponse) {
recognized = true;
long responseReceiveStart = currentChunkReceiveTime - requestStartTime;
perfClientMetrics.timeToFirstResponseChunkInMs.update(responseReceiveStart);
logger.trace("Response receive has started on channel {}. Took {} ms", ctx.channel(), responseReceiveStart);
response = (HttpResponse) in;
}
if (in instanceof HttpContent) {
recognized = true;
perfClientMetrics.delayBetweenChunkReceiveInMs.update(currentChunkReceiveTime - lastChunkReceiveTime);
chunksReceived++;
int bytesReceivedThisTime = ((HttpContent) in).content().readableBytes();
sizeReceived += bytesReceivedThisTime;
perfClientMetrics.bytesReceiveRate.mark(bytesReceivedThisTime);
if (in instanceof LastHttpContent) {
long requestRoundTripTime = currentChunkReceiveTime - requestStartTime;
perfClientMetrics.requestRoundTripTimeInMs.update(requestRoundTripTime);
perfClientMetrics.getContentSizeInBytes.update(sizeReceived);
perfClientMetrics.getChunkCount.update(chunksReceived);
logger.trace(
"Final content received on channel {}. Took {} ms. Total chunks received - {}. Total size received - {}",
ctx.channel(), requestRoundTripTime, chunksReceived, sizeReceived);
if (HttpUtil.isKeepAlive(response) && isRunning) {
logger.trace("Sending new request on channel {}", ctx.channel());
sendRequest(ctx);
} else if (!isRunning) {
logger.info("Closing channel {} because NettyPerfClient has been shutdown", ctx.channel());
ctx.close();
} else {
perfClientMetrics.requestResponseError.inc();
logger.error("Channel {} not kept alive. Last response status was {}", ctx.channel(), response.status());
ctx.close();
}
}
}
if (!recognized) {
throw new IllegalStateException("Unexpected HttpObject - " + in.getClass());
}
lastChunkReceiveTime = currentChunkReceiveTime;
}
@Override
public void channelInactive(ChannelHandlerContext ctx) {
logger.trace("Channel {} inactive", ctx.channel());
ctx.close();
if (isRunning) {
perfClientMetrics.unexpectedDisconnectionError.inc();
logger.info("Creating a new channel to keep up concurrency");
b.connect().addListener(channelConnectListener);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
perfClientMetrics.requestResponseError.inc();
logger.error("Exception caught on channel {} while processing request/response", ctx.channel(), cause);
ctx.close();
}
/**
* Sends the request according to the configuration.
* @param ctx the {@link ChannelHandlerContext} to use to send the request.
*/
private void sendRequest(ChannelHandlerContext ctx) {
requestId++;
long globalId = totalRequestCount.incrementAndGet();
logger.trace("Sending request with global ID {} and local ID {} on channel {}", globalId, requestId,
ctx.channel());
reset();
perfClientMetrics.requestRate.mark();
ctx.writeAndFlush(request);
if (request.method().equals(HttpMethod.POST)) {
ctx.writeAndFlush(chunkedInput);
}
logger.trace("Request {} scheduled to be sent on channel {}", requestId, ctx.channel());
}
/**
* Resets all state in preparation for the next request-response.
*/
private void reset() {
if (chunk != null) {
chunkedInput = new HttpChunkedInput(new RepeatedBytesInput());
request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri);
HttpUtil.setContentLength(request, totalSize);
request.headers().add(RestUtils.Headers.BLOB_SIZE, totalSize);
request.headers().add(RestUtils.Headers.SERVICE_ID, "PerfNettyClient");
request.headers().add(RestUtils.Headers.AMBRY_CONTENT_TYPE, "application/octet-stream");
} else {
request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);
}
chunksReceived = 0;
sizeReceived = 0;
lastChunkReceiveTime = 0;
requestStartTime = System.currentTimeMillis();
response = null;
}
/**
* Returns a chunk with the same data again and again until a fixed size is reached.
*/
private class RepeatedBytesInput implements ChunkedInput<ByteBuf> {
private final AtomicBoolean metricRecorded = new AtomicBoolean(false);
private long streamed = 0;
private long startTime;
private long lastChunkSendTime = 0;
private final Logger logger = LoggerFactory.getLogger(getClass());
/**
* Creates an instance that repeatedly sends the same chunk up to the configured size.
*/
protected RepeatedBytesInput() {
if (totalSize < 0 || (totalSize > 0 && chunk.length < 1)) {
throw new IllegalArgumentException("Invalid argument(s)");
}
}
@Override
public boolean isEndOfInput() {
boolean isEndOfInput = streamed >= totalSize;
if (isEndOfInput && metricRecorded.compareAndSet(false, true)) {
long postChunksTime = System.currentTimeMillis() - startTime;
perfClientMetrics.postChunksTimeInMs.update(postChunksTime);
logger.debug("Took {} ms to POST the blob of size {}", postChunksTime, streamed);
}
return isEndOfInput;
}
@Override
public void close() {
logger.debug("Size streamed - {}", streamed);
}
@Override
public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception {
return readChunk(ctx.alloc());
}
@Override
public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
ByteBuf buf = null;
if (streamed == 0) {
startTime = System.currentTimeMillis();
}
if (!isEndOfInput()) {
long currentChunkSendTime = System.currentTimeMillis();
int remaining = (totalSize - streamed) > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) (totalSize - streamed);
int toWrite = Math.min(chunk.length, remaining);
buf = Unpooled.wrappedBuffer(chunk, 0, toWrite);
streamed += toWrite;
if (lastChunkSendTime > 0) {
perfClientMetrics.delayBetweenChunkSendInMs.update(currentChunkSendTime - lastChunkSendTime);
}
lastChunkSendTime = currentChunkSendTime;
}
return buf;
}
@Override
public long length() {
return totalSize;
}
@Override
public long progress() {
return streamed;
}
}
}
/**
* Channel connection listener that prints error if channel could not be connected.
*/
private class ChannelConnectListener implements GenericFutureListener<ChannelFuture> {
@Override
public void operationComplete(ChannelFuture future) {
if (!future.isSuccess()) {
perfClientMetrics.connectError.inc();
logger.error("Channel {} to {}:{} could not be connected.", future.channel(), host, port, future.cause());
}
}
}
/**
* Metrics that track peformance.
*/
private static class PerfClientMetrics {
public final Meter bytesReceiveRate;
public final Meter channelCreationRate;
public final Meter requestRate;
public final Histogram delayBetweenChunkReceiveInMs;
public final Histogram delayBetweenChunkSendInMs;
public final Histogram getContentSizeInBytes;
public final Histogram getChunkCount;
public final Histogram postChunksTimeInMs;
public final Histogram requestRoundTripTimeInMs;
public final Histogram timeToFirstResponseChunkInMs;
public final Counter connectError;
public final Counter requestResponseError;
public final Counter unexpectedDisconnectionError;
/**
* Creates an instance of PerfClientMetrics.
* @param metricRegistry the {@link MetricRegistry} instance to use.
*/
protected PerfClientMetrics(MetricRegistry metricRegistry) {
bytesReceiveRate = metricRegistry.meter(MetricRegistry.name(ResponseHandler.class, "BytesReceiveRate"));
channelCreationRate = metricRegistry.meter(MetricRegistry.name(ResponseHandler.class, "ChannelCreationRate"));
requestRate = metricRegistry.meter(MetricRegistry.name(ResponseHandler.class, "RequestRate"));
delayBetweenChunkReceiveInMs =
metricRegistry.histogram(MetricRegistry.name(ResponseHandler.class, "DelayBetweenChunkReceiveInMs"));
delayBetweenChunkSendInMs = metricRegistry.histogram(
MetricRegistry.name(ResponseHandler.RepeatedBytesInput.class, "DelayBetweenChunkSendInMs"));
getContentSizeInBytes = metricRegistry.histogram(
MetricRegistry.name(ResponseHandler.RepeatedBytesInput.class, "GetContentSizeInBytes"));
getChunkCount =
metricRegistry.histogram(MetricRegistry.name(ResponseHandler.RepeatedBytesInput.class, "GetChunkCount"));
postChunksTimeInMs =
metricRegistry.histogram(MetricRegistry.name(ResponseHandler.RepeatedBytesInput.class, "PostChunksTimeInMs"));
requestRoundTripTimeInMs =
metricRegistry.histogram(MetricRegistry.name(ResponseHandler.class, "RequestRoundTripTimeInMs"));
timeToFirstResponseChunkInMs =
metricRegistry.histogram(MetricRegistry.name(ResponseHandler.class, "TimeToFirstResponseChunkInMs"));
connectError = metricRegistry.counter(MetricRegistry.name(ResponseHandler.class, "ConnectError"));
requestResponseError = metricRegistry.counter(MetricRegistry.name(ResponseHandler.class, "RequestResponseError"));
unexpectedDisconnectionError =
metricRegistry.counter(MetricRegistry.name(ResponseHandler.class, "UnexpectedDisconnectionError"));
}
}
}