/**
* Copyright 2016 LinkedIn Corp. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
package com.github.ambry.rest;
import com.codahale.metrics.MetricRegistry;
import com.github.ambry.commons.ByteBufferAsyncWritableChannel;
import com.github.ambry.config.NettyConfig;
import com.github.ambry.config.VerifiableProperties;
import com.github.ambry.router.AsyncWritableChannel;
import com.github.ambry.router.Callback;
import com.github.ambry.router.FutureResult;
import com.github.ambry.utils.TestUtils;
import com.github.ambry.utils.Utils;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.buffer.UnpooledHeapByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.DefaultMaxBytesRecvByteBufAllocator;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.DefaultCookie;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Queue;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.junit.Assert;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Tests functionality of {@link NettyRequest}.
*/
public class NettyRequestTest {
private static final int GENERATED_CONTENT_SIZE = 10240;
private static final int GENERATED_CONTENT_PART_COUNT = 10;
private static final int DEFAULT_WATERMARK;
static {
DEFAULT_WATERMARK = new NettyConfig(new VerifiableProperties(new Properties())).nettyServerRequestBufferWatermark;
}
public NettyRequestTest() {
NettyRequest.bufferWatermark = DEFAULT_WATERMARK;
}
/**
* Tests conversion of {@link HttpRequest} to {@link NettyRequest} given good input.
* @throws RestServiceException
*/
@Test
public void conversionWithGoodInputTest() throws RestServiceException, CertificateException, SSLException {
// headers
HttpHeaders headers = new DefaultHttpHeaders(false);
headers.add(HttpHeaderNames.CONTENT_LENGTH, new Random().nextInt(Integer.MAX_VALUE));
headers.add("headerKey", "headerValue1");
headers.add("headerKey", "headerValue2");
headers.add("overLoadedKey", "headerOverloadedValue");
headers.add("paramNoValueInUriButValueInHeader", "paramValueInHeader");
// params
Map<String, List<String>> params = new HashMap<String, List<String>>();
List<String> values = new ArrayList<String>(2);
values.add("paramValue1");
values.add("paramValue2");
params.put("paramKey", values);
values = new ArrayList<String>(1);
values.add("paramOverloadedValue");
params.put("overLoadedKey", values);
params.put("paramNoValue", null);
params.put("paramNoValueInUriButValueInHeader", null);
StringBuilder uriAttachmentBuilder = new StringBuilder("?");
for (Map.Entry<String, List<String>> param : params.entrySet()) {
if (param.getValue() != null) {
for (String value : param.getValue()) {
uriAttachmentBuilder.append(param.getKey()).append("=").append(value).append("&");
}
} else {
uriAttachmentBuilder.append(param.getKey()).append("&");
}
}
uriAttachmentBuilder.deleteCharAt(uriAttachmentBuilder.length() - 1);
String uriAttachment = uriAttachmentBuilder.toString();
NettyRequest nettyRequest;
String uri;
Set<Cookie> cookies = new HashSet<>();
Cookie httpCookie = new DefaultCookie("CookieKey1", "CookieValue1");
cookies.add(httpCookie);
httpCookie = new DefaultCookie("CookieKey2", "CookieValue2");
cookies.add(httpCookie);
headers.add(RestUtils.Headers.COOKIE, getCookiesHeaderValue(cookies));
for (MockChannel channel : Arrays.asList(new MockChannel(), new MockChannel().addSslHandlerToPipeline())) {
uri = "/GET" + uriAttachment;
nettyRequest = createNettyRequest(HttpMethod.GET, uri, headers, channel);
validateRequest(nettyRequest, RestMethod.GET, uri, headers, params, cookies, channel);
closeRequestAndValidate(nettyRequest, channel);
RecvByteBufAllocator savedAllocator = channel.config().getRecvByteBufAllocator();
int[] bufferWatermarks = {-1, 0, 1, DEFAULT_WATERMARK};
for (int bufferWatermark : bufferWatermarks) {
NettyRequest.bufferWatermark = bufferWatermark;
uri = "/POST" + uriAttachment;
nettyRequest = createNettyRequest(HttpMethod.POST, uri, headers, channel);
validateRequest(nettyRequest, RestMethod.POST, uri, headers, params, cookies, channel);
if (bufferWatermark > 0) {
assertTrue("RecvAllocator should have changed",
channel.config().getRecvByteBufAllocator() instanceof DefaultMaxBytesRecvByteBufAllocator);
} else {
assertEquals("RecvAllocator not as expected", savedAllocator, channel.config().getRecvByteBufAllocator());
}
closeRequestAndValidate(nettyRequest, channel);
assertEquals("Allocator not as expected", savedAllocator, channel.config().getRecvByteBufAllocator());
}
for (int bufferWatermark : bufferWatermarks) {
NettyRequest.bufferWatermark = bufferWatermark;
uri = "/PUT" + uriAttachment;
nettyRequest = createNettyRequest(HttpMethod.PUT, uri, headers, channel);
validateRequest(nettyRequest, RestMethod.PUT, uri, headers, params, cookies, channel);
if (bufferWatermark > 0) {
assertTrue("RecvAllocator should have changed",
channel.config().getRecvByteBufAllocator() instanceof DefaultMaxBytesRecvByteBufAllocator);
} else {
assertEquals("RecvAllocator not as expected", savedAllocator, channel.config().getRecvByteBufAllocator());
}
closeRequestAndValidate(nettyRequest, channel);
assertEquals("Allocator not as expected", savedAllocator, channel.config().getRecvByteBufAllocator());
}
NettyRequest.bufferWatermark = DEFAULT_WATERMARK;
uri = "/DELETE" + uriAttachment;
nettyRequest = createNettyRequest(HttpMethod.DELETE, uri, headers, channel);
validateRequest(nettyRequest, RestMethod.DELETE, uri, headers, params, cookies, channel);
closeRequestAndValidate(nettyRequest, channel);
uri = "/HEAD" + uriAttachment;
nettyRequest = createNettyRequest(HttpMethod.HEAD, uri, headers, channel);
validateRequest(nettyRequest, RestMethod.HEAD, uri, headers, params, cookies, channel);
closeRequestAndValidate(nettyRequest, channel);
}
}
/**
* Tests conversion of {@link HttpRequest} to {@link NettyRequest} given bad input (i.e. checks for the correct
* exception and {@link RestServiceErrorCode} if any).
* @throws RestServiceException
*/
@Test
public void conversionWithBadInputTest() throws RestServiceException {
HttpRequest httpRequest = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "");
// HttpRequest null.
try {
new NettyRequest(null, new MockChannel(), new NettyMetrics(new MetricRegistry()));
fail("Provided null HttpRequest to NettyRequest, yet it did not fail");
} catch (IllegalArgumentException e) {
// expected. nothing to do.
}
// Channel null.
try {
new NettyRequest(httpRequest, null, new NettyMetrics(new MetricRegistry()));
fail("Provided null Channel to NettyRequest, yet it did not fail");
} catch (IllegalArgumentException e) {
// expected. nothing to do.
}
// unknown http method
try {
createNettyRequest(HttpMethod.TRACE, "/", null, new MockChannel());
fail("Unknown http method was supplied to NettyRequest. It should have failed to construct");
} catch (RestServiceException e) {
assertEquals("Unexpected RestServiceErrorCode", RestServiceErrorCode.UnsupportedHttpMethod, e.getErrorCode());
}
String[] invalidBlobSizeStrs = {"aba123", "12ab", "-1", "ddsdd", "999999999999999999999999999", "1.234"};
for (String blobSizeStr : invalidBlobSizeStrs) {
// bad blob size
try {
createNettyRequest(HttpMethod.GET, "/", new DefaultHttpHeaders().add(RestUtils.Headers.BLOB_SIZE, blobSizeStr),
new MockChannel());
fail("Bad blob size header was supplied to NettyRequest. It should have failed to construct");
} catch (RestServiceException e) {
assertEquals("Unexpected RestServiceErrorCode", RestServiceErrorCode.InvalidArgs, e.getErrorCode());
}
}
}
/**
* Tests for behavior of multiple operations after {@link NettyRequest#close()} has been called. Some should be ok to
* do and some should throw exceptions.
* @throws Exception
*/
@Test
public void operationsAfterCloseTest() throws Exception {
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
closeRequestAndValidate(nettyRequest, channel);
// operations that should be ok to do (does not include all operations).
nettyRequest.close();
// operations that will throw exceptions.
AsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
ReadIntoCallback callback = new ReadIntoCallback();
try {
nettyRequest.readInto(writeChannel, callback).get();
fail("Request channel has been closed, so read should have thrown ClosedChannelException");
} catch (ExecutionException e) {
Exception exception = (Exception) Utils.getRootCause(e);
assertTrue("Exception is not ClosedChannelException", exception instanceof ClosedChannelException);
callback.awaitCallback();
assertEquals("Exceptions of callback and future differ", exception.getMessage(), callback.exception.getMessage());
}
try {
byte[] content = TestUtils.getRandomBytes(1024);
nettyRequest.addContent(new DefaultLastHttpContent(Unpooled.wrappedBuffer(content)));
fail("Request channel has been closed, so addContent() should have thrown ClosedChannelException");
} catch (RestServiceException e) {
assertEquals("Unexpected RestServiceErrorCode", RestServiceErrorCode.RequestChannelClosed, e.getErrorCode());
}
}
/**
* Tests {@link NettyRequest#addContent(HttpContent)} and
* {@link NettyRequest#readInto(AsyncWritableChannel, Callback)} with different digest algorithms (including a test
* with no digest algorithm).
* @throws Exception
*/
@Test
public void contentAddAndReadTest() throws Exception {
String[] digestAlgorithms = {"", "MD5", "SHA-1", "SHA-256"};
HttpMethod[] methods = {HttpMethod.POST, HttpMethod.PUT};
for (HttpMethod method : methods) {
for (String digestAlgorithm : digestAlgorithms) {
contentAddAndReadTest(digestAlgorithm, true, method);
contentAddAndReadTest(digestAlgorithm, false, method);
}
}
}
/**
* Tests {@link NettyRequest#addContent(HttpContent)} and
* {@link NettyRequest#readInto(AsyncWritableChannel, Callback)} with different digest algorithms (including a test
* with no digest algorithm) and checks that back pressure is applied correctly.
* @throws Exception
*/
@Test
public void backPressureTest() throws Exception {
String[] digestAlgorithms = {"", "MD5", "SHA-1", "SHA-256"};
HttpMethod[] methods = {HttpMethod.POST, HttpMethod.PUT};
for (HttpMethod method : methods) {
for (String digestAlgorithm : digestAlgorithms) {
backPressureTest(digestAlgorithm, true, method);
backPressureTest(digestAlgorithm, false, method);
}
}
}
/**
* Tests exception scenarios of {@link NettyRequest#readInto(AsyncWritableChannel, Callback)} and behavior of
* {@link NettyRequest} when {@link AsyncWritableChannel} instances fail.
* @throws Exception
*/
@Test
public void readIntoExceptionsTest() throws Exception {
Channel channel = new MockChannel();
// try to call readInto twice.
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
AsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
nettyRequest.readInto(writeChannel, null);
try {
nettyRequest.readInto(writeChannel, null);
fail("Calling readInto twice should have failed");
} catch (IllegalStateException e) {
// expected. Nothing to do.
}
closeRequestAndValidate(nettyRequest, channel);
// write into a channel that throws exceptions
// non RuntimeException
nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
List<HttpContent> httpContents = new ArrayList<HttpContent>();
generateContent(httpContents);
assertTrue("Not enough content has been generated", httpContents.size() > 2);
String expectedMsg = "@@expectedMsg@@";
Exception exception = new Exception(expectedMsg);
writeChannel = new BadAsyncWritableChannel(exception);
ReadIntoCallback callback = new ReadIntoCallback();
// add content initially
int addedCount = 0;
for (; addedCount < httpContents.size() / 2; addedCount++) {
HttpContent httpContent = httpContents.get(addedCount);
nettyRequest.addContent(httpContent);
assertEquals("Reference count is not as expected", 2, httpContent.refCnt());
}
Future<Long> future = nettyRequest.readInto(writeChannel, callback);
// add some more content
for (; addedCount < httpContents.size(); addedCount++) {
HttpContent httpContent = httpContents.get(addedCount);
nettyRequest.addContent(httpContent);
}
writeChannel.close();
verifyRefCnts(httpContents);
callback.awaitCallback();
assertNotNull("Exception was not piped correctly", callback.exception);
assertEquals("Exception message mismatch (callback)", expectedMsg, callback.exception.getMessage());
try {
future.get();
fail("Future should have thrown exception");
} catch (ExecutionException e) {
assertEquals("Exception message mismatch (future)", expectedMsg, Utils.getRootCause(e).getMessage());
}
closeRequestAndValidate(nettyRequest, channel);
// RuntimeException
// during readInto
nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
httpContents = new ArrayList<HttpContent>();
generateContent(httpContents);
exception = new IllegalStateException(expectedMsg);
writeChannel = new BadAsyncWritableChannel(exception);
callback = new ReadIntoCallback();
for (HttpContent httpContent : httpContents) {
nettyRequest.addContent(httpContent);
assertEquals("Reference count is not as expected", 2, httpContent.refCnt());
}
try {
nettyRequest.readInto(writeChannel, callback);
fail("readInto did not throw expected exception");
} catch (Exception e) {
assertEquals("Exception caught does not match expected exception", expectedMsg, e.getMessage());
}
writeChannel.close();
closeRequestAndValidate(nettyRequest, channel);
verifyRefCnts(httpContents);
// after readInto
nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
httpContents = new ArrayList<HttpContent>();
generateContent(httpContents);
exception = new IllegalStateException(expectedMsg);
writeChannel = new BadAsyncWritableChannel(exception);
callback = new ReadIntoCallback();
nettyRequest.readInto(writeChannel, callback);
// add content
HttpContent httpContent = httpContents.get(1);
try {
nettyRequest.addContent(httpContent);
fail("addContent did not throw expected exception");
} catch (Exception e) {
assertEquals("Exception caught does not match expected exception", expectedMsg, e.getMessage());
}
writeChannel.close();
closeRequestAndValidate(nettyRequest, channel);
verifyRefCnts(httpContents);
}
/**
* Tests that {@link NettyRequest#close()} leaves any added {@link HttpContent} the way it was before it was added.
* (i.e no reference count changes).
* @throws RestServiceException
*/
@Test
public void closeTest() throws RestServiceException {
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
Queue<HttpContent> httpContents = new LinkedBlockingQueue<HttpContent>();
for (int i = 0; i < 5; i++) {
ByteBuffer content = ByteBuffer.wrap(TestUtils.getRandomBytes(1024));
HttpContent httpContent = new DefaultHttpContent(Unpooled.wrappedBuffer(content));
nettyRequest.addContent(httpContent);
httpContents.add(httpContent);
}
closeRequestAndValidate(nettyRequest, channel);
while (httpContents.peek() != null) {
assertEquals("Reference count of http content has changed", 1, httpContents.poll().refCnt());
}
}
/**
* Tests different state transitions that can happen with {@link NettyRequest#addContent(HttpContent)} for GET
* requests. Some transitions are valid and some should necessarily throw exceptions.
* @throws RestServiceException
*/
@Test
public void addContentForGetTest() throws RestServiceException {
byte[] content = TestUtils.getRandomBytes(16);
// adding non LastHttpContent to nettyRequest
NettyRequest nettyRequest = createNettyRequest(HttpMethod.GET, "/", null, new MockChannel());
try {
nettyRequest.addContent(new DefaultHttpContent(Unpooled.wrappedBuffer(content)));
fail("GET requests should not accept non-LastHTTPContent");
} catch (IllegalStateException e) {
// expected. nothing to do.
}
// adding LastHttpContent with some content to nettyRequest
nettyRequest = createNettyRequest(HttpMethod.GET, "/", null, new MockChannel());
try {
nettyRequest.addContent(new DefaultLastHttpContent(Unpooled.wrappedBuffer(content)));
fail("GET requests should not accept actual content in LastHTTPContent");
} catch (IllegalStateException e) {
// expected. nothing to do.
}
// should accept LastHttpContent just fine.
nettyRequest = createNettyRequest(HttpMethod.GET, "/", null, new MockChannel());
nettyRequest.addContent(new DefaultLastHttpContent());
// should not accept LastHttpContent after close
nettyRequest = createNettyRequest(HttpMethod.GET, "/", null, new MockChannel());
nettyRequest.close();
try {
nettyRequest.addContent(new DefaultLastHttpContent());
fail("Request channel has been closed, so addContent() should have thrown ClosedChannelException");
} catch (RestServiceException e) {
assertEquals("Unexpected RestServiceErrorCode", RestServiceErrorCode.RequestChannelClosed, e.getErrorCode());
}
}
@Test
public void keepAliveTest() throws RestServiceException {
NettyRequest request = createNettyRequest(HttpMethod.GET, "/", null, new MockChannel());
// by default, keep-alive is true for HTTP 1.1
assertTrue("Keep-alive not as expected", request.isKeepAlive());
HttpHeaders headers = new DefaultHttpHeaders();
headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
request = createNettyRequest(HttpMethod.GET, "/", headers, new MockChannel());
assertTrue("Keep-alive not as expected", request.isKeepAlive());
headers = new DefaultHttpHeaders();
headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
request = createNettyRequest(HttpMethod.GET, "/", headers, new MockChannel());
assertFalse("Keep-alive not as expected", request.isKeepAlive());
}
/**
* Tests the {@link NettyRequest#getSize()} function to see that it respects priorities.
* @throws RestServiceException
*/
@Test
public void sizeTest() throws RestServiceException {
// no length headers provided.
NettyRequest nettyRequest = createNettyRequest(HttpMethod.GET, "/", null, new MockChannel());
assertEquals("Size not as expected", -1, nettyRequest.getSize());
// deliberate mismatch to check priorities.
int xAmbryBlobSize = 20;
int contentLength = 10;
// Content-Length header set
HttpHeaders headers = new DefaultHttpHeaders();
headers.add(HttpHeaderNames.CONTENT_LENGTH, contentLength);
nettyRequest = createNettyRequest(HttpMethod.GET, "/", headers, new MockChannel());
assertEquals("Size not as expected", contentLength, nettyRequest.getSize());
// xAmbryBlobSize set
headers = new DefaultHttpHeaders();
headers.add(RestUtils.Headers.BLOB_SIZE, xAmbryBlobSize);
nettyRequest = createNettyRequest(HttpMethod.GET, "/", headers, new MockChannel());
assertEquals("Size not as expected", xAmbryBlobSize, nettyRequest.getSize());
// both set
headers = new DefaultHttpHeaders();
headers.add(RestUtils.Headers.BLOB_SIZE, xAmbryBlobSize);
headers.add(HttpHeaderNames.CONTENT_LENGTH, contentLength);
nettyRequest = createNettyRequest(HttpMethod.GET, "/", headers, new MockChannel());
assertEquals("Size not as expected", xAmbryBlobSize, nettyRequest.getSize());
}
/**
* Tests for POST request that has no content.
* @throws Exception
*/
@Test
public void zeroSizeContentTest() throws Exception {
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
HttpContent httpContent = new DefaultLastHttpContent();
nettyRequest.addContent(httpContent);
assertEquals("Reference count is not as expected", 2, httpContent.refCnt());
ByteBufferAsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
ReadIntoCallback callback = new ReadIntoCallback();
Future<Long> future = nettyRequest.readInto(writeChannel, callback);
assertEquals("There should be no content", 0, writeChannel.getNextChunk().remaining());
writeChannel.resolveOldestChunk(null);
closeRequestAndValidate(nettyRequest, channel);
writeChannel.close();
assertEquals("Reference count of http content has changed", 1, httpContent.refCnt());
callback.awaitCallback();
if (callback.exception != null) {
throw callback.exception;
}
long futureBytesRead = future.get();
assertEquals("Total bytes read does not match (callback)", 0, callback.bytesRead);
assertEquals("Total bytes read does not match (future)", 0, futureBytesRead);
}
/**
* Tests reaction of NettyRequest when content size is different from the size specified in the headers.
* @throws Exception
*/
@Test
public void headerAndContentSizeMismatchTest() throws Exception {
sizeInHeaderMoreThanContentTest();
sizeInHeaderLessThanContentTest();
}
/**
* Does any left over tests for {@link NettyRequest.ContentWriteCallback}
*/
@Test
public void contentWriteCallbackTests() throws RestServiceException {
ReadIntoCallback readIntoCallback = new ReadIntoCallback();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.GET, "/", null, new MockChannel());
NettyRequest.ReadIntoCallbackWrapper wrapper = nettyRequest.new ReadIntoCallbackWrapper(readIntoCallback);
NettyRequest.ContentWriteCallback callback = nettyRequest.new ContentWriteCallback(null, true, wrapper);
long bytesRead = new Random().nextInt(Integer.MAX_VALUE);
// there should be no problem even though httpContent is null.
callback.onCompletion(bytesRead, null);
assertEquals("Bytes read does not match", bytesRead, readIntoCallback.bytesRead);
}
/**
* Tests for incorrect usage of {@link NettyRequest#setDigestAlgorithm(String)} and {@link NettyRequest#getDigest()}.
* @throws NoSuchAlgorithmException
* @throws RestServiceException
*/
@Test
public void digestIncorrectUsageTest() throws NoSuchAlgorithmException, RestServiceException {
setDigestAfterReadTest();
setBadAlgorithmTest();
getDigestWithoutSettingAlgorithmTest();
getDigestBeforeAllContentProcessedTest();
}
// helpers
// general
/**
* Creates a {@link NettyRequest} with the given parameters.
* @param httpMethod the {@link HttpMethod} desired.
* @param uri the URI desired.
* @param headers {@link HttpHeaders} that need to be a part of the request.
* @param channel the {@link Channel} that the request arrived over.
* @return {@link NettyRequest} encapsulating a {@link HttpRequest} with the given parameters.
* @throws RestServiceException if the {@code httpMethod} is not recognized by {@link NettyRequest}.
*/
private NettyRequest createNettyRequest(HttpMethod httpMethod, String uri, HttpHeaders headers, Channel channel)
throws RestServiceException {
MetricRegistry metricRegistry = new MetricRegistry();
RestRequestMetricsTracker.setDefaults(metricRegistry);
HttpRequest httpRequest = new DefaultHttpRequest(HttpVersion.HTTP_1_1, httpMethod, uri, false);
if (headers != null) {
httpRequest.headers().set(headers);
}
NettyRequest nettyRequest = new NettyRequest(httpRequest, channel, new NettyMetrics(metricRegistry));
assertEquals("Auto-read is in an invalid state",
(!httpMethod.equals(HttpMethod.POST) && !httpMethod.equals(HttpMethod.PUT))
|| NettyRequest.bufferWatermark <= 0, channel.config().isAutoRead());
return nettyRequest;
}
/**
* Closes the provided {@code nettyRequest} and validates that it is actually closed.
* @param nettyRequest the {@link NettyRequest} that needs to be closed and validated.
* @param channel the {@link Channel} over which the request was received.
*/
private void closeRequestAndValidate(NettyRequest nettyRequest, Channel channel) {
nettyRequest.close();
assertFalse("Request channel is not closed", nettyRequest.isOpen());
assertTrue("Auto-read is not as expected", channel.config().isAutoRead());
}
/**
* Convert a set of {@link Cookie} to a string that could be used as header value in http request
* @param cookies that needs conversion
* @return string representation of the set of cookies
*/
private String getCookiesHeaderValue(Set<Cookie> cookies) {
StringBuilder cookieStr = new StringBuilder();
for (Cookie cookie : cookies) {
if (cookieStr.length() != 0) {
cookieStr.append("; ");
}
cookieStr.append(cookie.name()).append("=").append(cookie.value());
}
return cookieStr.toString();
}
// conversionWithGoodInputTest() helpers
/**
* Validates the various expected properties of the provided {@code nettyRequest}.
* @param nettyRequest the {@link NettyRequest} that needs to be validated.
* @param restMethod the expected {@link RestMethod} in {@code nettyRequest}.
* @param uri the expected URI in {@code nettyRequest}.
* @param headers the {@link HttpHeaders} passed with the request that need to be in {@link NettyRequest#getArgs()}.
* @param params the parameters passed with the request that need to be in {@link NettyRequest#getArgs()}.
* @param httpCookies Set of {@link Cookie} set in the request
* @param channel the {@link MockChannel} over which the request was received.
*/
private void validateRequest(NettyRequest nettyRequest, RestMethod restMethod, String uri, HttpHeaders headers,
Map<String, List<String>> params, Set<Cookie> httpCookies, MockChannel channel) {
long contentLength =
headers.contains(HttpHeaderNames.CONTENT_LENGTH) ? Long.parseLong(headers.get(HttpHeaderNames.CONTENT_LENGTH))
: 0;
assertTrue("Request channel is not open", nettyRequest.isOpen());
assertEquals("Mismatch in content length", contentLength, nettyRequest.getSize());
assertEquals("Mismatch in rest method", restMethod, nettyRequest.getRestMethod());
assertEquals("Mismatch in path", uri.substring(0, uri.indexOf("?")), nettyRequest.getPath());
assertEquals("Mismatch in uri", uri, nettyRequest.getUri());
assertNotNull("There should have been a RestRequestMetricsTracker", nettyRequest.getMetricsTracker());
assertFalse("Should not have been a multipart request", nettyRequest.isMultipart());
SSLSession sslSession = nettyRequest.getSSLSession();
if (channel.pipeline().get(SslHandler.class) == null) {
assertNull("Non-null SSLSession when pipeline does not contain an SslHandler", sslSession);
} else {
assertEquals("SSLSession does not match one from MockChannel", channel.getSSLEngine().getSession(), sslSession);
}
Set<javax.servlet.http.Cookie> actualCookies =
(Set<javax.servlet.http.Cookie>) nettyRequest.getArgs().get(RestUtils.Headers.COOKIE);
compareCookies(httpCookies, actualCookies);
Map<String, List<String>> receivedArgs = new HashMap<String, List<String>>();
for (Map.Entry<String, Object> e : nettyRequest.getArgs().entrySet()) {
if (!e.getKey().equalsIgnoreCase(HttpHeaderNames.COOKIE.toString())) {
if (!receivedArgs.containsKey(e.getKey())) {
receivedArgs.put(e.getKey(), new LinkedList<String>());
}
if (e.getValue() != null) {
List<String> values =
Arrays.asList(e.getValue().toString().split(NettyRequest.MULTIPLE_HEADER_VALUE_DELIMITER));
receivedArgs.get(e.getKey()).addAll(values);
}
}
}
Map<String, Integer> keyValueCount = new HashMap<String, Integer>();
for (Map.Entry<String, List<String>> param : params.entrySet()) {
assertTrue("Did not find key: " + param.getKey(), receivedArgs.containsKey(param.getKey()));
if (!keyValueCount.containsKey(param.getKey())) {
keyValueCount.put(param.getKey(), 0);
}
if (param.getValue() != null) {
boolean containsAllValues = receivedArgs.get(param.getKey()).containsAll(param.getValue());
assertTrue("Did not find all values expected for key: " + param.getKey(), containsAllValues);
keyValueCount.put(param.getKey(), keyValueCount.get(param.getKey()) + param.getValue().size());
}
}
for (Map.Entry<String, String> e : headers) {
if (!e.getKey().equalsIgnoreCase(HttpHeaderNames.COOKIE.toString())) {
assertTrue("Did not find key: " + e.getKey(), receivedArgs.containsKey(e.getKey()));
if (!keyValueCount.containsKey(e.getKey())) {
keyValueCount.put(e.getKey(), 0);
}
if (headers.get(e.getKey()) != null) {
assertTrue("Did not find value '" + e.getValue() + "' expected for key: '" + e.getKey() + "'",
receivedArgs.get(e.getKey()).contains(e.getValue()));
keyValueCount.put(e.getKey(), keyValueCount.get(e.getKey()) + 1);
}
}
}
assertEquals("Number of args does not match", keyValueCount.size(), receivedArgs.size());
for (Map.Entry<String, Integer> e : keyValueCount.entrySet()) {
assertEquals("Value count for key " + e.getKey() + " does not match", e.getValue().intValue(),
receivedArgs.get(e.getKey()).size());
}
assertEquals("Auto-read is in an invalid state",
(!restMethod.equals(RestMethod.POST) && !restMethod.equals(RestMethod.PUT))
|| NettyRequest.bufferWatermark <= 0, channel.config().isAutoRead());
}
/**
* Compares a set of HttpCookies {@link Cookie} with a set of Java Cookies {@link javax.servlet.http.Cookie} for
* equality in values
* @param expected Set of {@link Cookie}s to be compared with the {@code actual}
* @param actual Set of {@link javax.servlet.http.Cookie}s to be compared with those of {@code expected}
*/
private void compareCookies(Set<Cookie> expected, Set<javax.servlet.http.Cookie> actual) {
Assert.assertEquals("Size didn't match", expected.size(), actual.size());
HashMap<String, Cookie> expectedHashMap = new HashMap<String, Cookie>();
for (Cookie cookie : expected) {
expectedHashMap.put(cookie.name(), cookie);
}
for (javax.servlet.http.Cookie cookie : actual) {
Assert.assertEquals("Value field didn't match ", expectedHashMap.get(cookie.getName()).value(),
cookie.getValue());
}
}
// contentAddAndReadTest(), readIntoExceptionsTest() and backPressureTest() helpers
/**
* Tests {@link NettyRequest#addContent(HttpContent)} and
* {@link NettyRequest#readInto(AsyncWritableChannel, Callback)} by creating a {@link NettyRequest}, adding a few
* pieces of content to it and then reading from it to match the stream with the added content.
* <p/>
* The read happens at different points of time w.r.t content addition (before, during, after).
* @param digestAlgorithm the digest algorithm to use. Can be empty or {@code null} if digest checking is not
* required.
* @param useCopyForcingByteBuf if {@code true}, uses {@link CopyForcingByteBuf} instead of the default
* {@link ByteBuf}.
* @param method Http method
* @throws Exception
*/
private void contentAddAndReadTest(String digestAlgorithm, boolean useCopyForcingByteBuf, HttpMethod method)
throws Exception {
// non composite content
// start reading before addition of content
List<HttpContent> httpContents = new ArrayList<>();
ByteBuffer content = generateContent(httpContents, useCopyForcingByteBuf);
doContentAddAndReadTest(digestAlgorithm, content, httpContents, 0, method);
// start reading in the middle of content add
httpContents.clear();
content = generateContent(httpContents, useCopyForcingByteBuf);
doContentAddAndReadTest(digestAlgorithm, content, httpContents, httpContents.size() / 2, method);
// start reading after all content added
httpContents.clear();
content = generateContent(httpContents, useCopyForcingByteBuf);
doContentAddAndReadTest(digestAlgorithm, content, httpContents, httpContents.size(), method);
// composite content
httpContents.clear();
content = generateCompositeContent(httpContents);
doContentAddAndReadTest(digestAlgorithm, content, httpContents, 0, method);
}
/**
* Does the content addition and read verification based on the arguments provided.
* @param digestAlgorithm the digest algorithm to use. Can be empty or {@code null} if digest checking is not
* required.
* @param content the complete content.
* @param httpContents {@code content} in parts and as {@link HttpContent}. Should contain all the data in
* {@code content}.
* @param numChunksToAddBeforeRead the number of {@link HttpContent} to add before making the
* {@link NettyRequest#readInto(AsyncWritableChannel, Callback)} call.
* @param method Http method
* @throws Exception
*/
private void doContentAddAndReadTest(String digestAlgorithm, ByteBuffer content, List<HttpContent> httpContents,
int numChunksToAddBeforeRead, HttpMethod method) throws Exception {
if (numChunksToAddBeforeRead < 0 || numChunksToAddBeforeRead > httpContents.size()) {
throw new IllegalArgumentException("Illegal value of numChunksToAddBeforeRead");
}
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(method, "/", null, channel);
byte[] wholeDigest = null;
if (digestAlgorithm != null && !digestAlgorithm.isEmpty()) {
MessageDigest digest = MessageDigest.getInstance(digestAlgorithm);
digest.update(content);
wholeDigest = digest.digest();
content.rewind();
nettyRequest.setDigestAlgorithm(digestAlgorithm);
}
int bytesToVerify = 0;
int addedCount = 0;
for (; addedCount < numChunksToAddBeforeRead; addedCount++) {
HttpContent httpContent = httpContents.get(addedCount);
bytesToVerify += httpContent.content().readableBytes();
nettyRequest.addContent(httpContent);
// ref count always 2 when added before calling readInto()
assertEquals("Reference count is not as expected", 2, httpContent.refCnt());
}
ByteBufferAsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
ReadIntoCallback callback = new ReadIntoCallback();
Future<Long> future = nettyRequest.readInto(writeChannel, callback);
readAndVerify(bytesToVerify, writeChannel, content);
bytesToVerify = 0;
for (; addedCount < httpContents.size(); addedCount++) {
HttpContent httpContent = httpContents.get(addedCount);
bytesToVerify += httpContent.content().readableBytes();
nettyRequest.addContent(httpContent);
int expectedRefCountOnAdd = httpContent.content().nioBufferCount() > 0 ? 2 : 1;
assertEquals("Reference count is not as expected", expectedRefCountOnAdd, httpContent.refCnt());
}
readAndVerify(bytesToVerify, writeChannel, content);
verifyRefCnts(httpContents);
writeChannel.close();
callback.awaitCallback();
if (callback.exception != null) {
throw callback.exception;
}
long futureBytesRead = future.get();
assertEquals("Total bytes read does not match (callback)", content.limit(), callback.bytesRead);
assertEquals("Total bytes read does not match (future)", content.limit(), futureBytesRead);
// check twice to make sure the same digest is returned every time
for (int i = 0; i < 2; i++) {
assertArrayEquals("Part by part digest should match digest of whole", wholeDigest, nettyRequest.getDigest());
}
closeRequestAndValidate(nettyRequest, channel);
}
/**
* Tests backpressure support in {@link NettyRequest} for different values of {@link NettyRequest#bufferWatermark}.
* @param digestAlgorithm the digest algorithm to use. Can be empty or {@code null} if digest checking is not
* required.
* @param useCopyForcingByteBuf if {@code true}, uses {@link CopyForcingByteBuf} instead of the default
* {@link ByteBuf}.
* @param method Http method
* @throws Exception
*/
private void backPressureTest(String digestAlgorithm, boolean useCopyForcingByteBuf, HttpMethod method)
throws Exception {
List<HttpContent> httpContents = new ArrayList<HttpContent>();
byte[] contentBytes = TestUtils.getRandomBytes(GENERATED_CONTENT_SIZE);
ByteBuffer content = ByteBuffer.wrap(contentBytes);
splitContent(contentBytes, GENERATED_CONTENT_PART_COUNT, httpContents, useCopyForcingByteBuf);
int chunkSize = httpContents.get(0).content().readableBytes();
int[] bufferWatermarks = {1, chunkSize - 1, chunkSize,
chunkSize + 1, chunkSize * httpContents.size() / 2, content.limit() - 1, content.limit(), content.limit() + 1};
for (int bufferWatermark : bufferWatermarks) {
NettyRequest.bufferWatermark = bufferWatermark;
// start reading before addition of content
httpContents.clear();
content.rewind();
splitContent(contentBytes, GENERATED_CONTENT_PART_COUNT, httpContents, useCopyForcingByteBuf);
doBackPressureTest(digestAlgorithm, content, httpContents, 0, method);
// start reading in the middle of content add
httpContents.clear();
content.rewind();
splitContent(contentBytes, GENERATED_CONTENT_PART_COUNT, httpContents, useCopyForcingByteBuf);
doBackPressureTest(digestAlgorithm, content, httpContents, httpContents.size() / 2, method);
// start reading after all content added
httpContents.clear();
content.rewind();
splitContent(contentBytes, GENERATED_CONTENT_PART_COUNT, httpContents, useCopyForcingByteBuf);
doBackPressureTest(digestAlgorithm, content, httpContents, httpContents.size(), method);
}
}
/**
* Does the backpressure test by ensuring that {@link Channel#read()} isn't called when the number of bytes buffered
* is above the {@link NettyRequest#bufferWatermark}. Also ensures that {@link Channel#read()} is called correctly
* when the number of buffered bytes falls below the {@link NettyRequest#bufferWatermark}.
* @param digestAlgorithm the digest algorithm to use. Can be empty or {@code null} if digest checking is not
* required.
* @param content the complete content.
* @param httpContents {@code content} in parts and as {@link HttpContent}. Should contain all the data in
* {@code content}.
* @param numChunksToAddBeforeRead the number of {@link HttpContent} to add before making the
* {@link NettyRequest#readInto(AsyncWritableChannel, Callback)} call.
* @param method Http Method
* @throws Exception
*/
private void doBackPressureTest(String digestAlgorithm, ByteBuffer content, List<HttpContent> httpContents,
int numChunksToAddBeforeRead, HttpMethod method) throws Exception {
if (numChunksToAddBeforeRead < 0 || numChunksToAddBeforeRead > httpContents.size()) {
throw new IllegalArgumentException("Illegal value of numChunksToAddBeforeRead");
}
MockChannel channel = new MockChannel();
final NettyRequest nettyRequest = createNettyRequest(method, "/", null, channel);
byte[] wholeDigest = null;
if (digestAlgorithm != null && !digestAlgorithm.isEmpty()) {
MessageDigest digest = MessageDigest.getInstance(digestAlgorithm);
digest.update(content);
wholeDigest = digest.digest();
content.rewind();
nettyRequest.setDigestAlgorithm(digestAlgorithm);
}
final AtomicInteger queuedReads = new AtomicInteger(0);
ByteBufferAsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
ReadIntoCallback callback = new ReadIntoCallback();
channel.setChannelReadCallback(new MockChannel.ChannelReadCallback() {
@Override
public void onRead() {
queuedReads.incrementAndGet();
}
});
int addedCount = 0;
Future<Long> future = null;
boolean suspended = false;
int bytesToVerify = 0;
while (addedCount < httpContents.size()) {
if (suspended) {
assertEquals("There should have been no reads queued when over buffer watermark", 0, queuedReads.get());
if (future == null) {
future = nettyRequest.readInto(writeChannel, callback);
}
int chunksRead = readAndVerify(bytesToVerify, writeChannel, content);
assertEquals("Number of reads triggered is not as expected", chunksRead, queuedReads.get());
// collapse many reads into one
queuedReads.set(1);
bytesToVerify = 0;
suspended = false;
} else {
assertEquals("There should have been only one read queued", 1, queuedReads.get());
queuedReads.set(0);
if (future == null && addedCount == numChunksToAddBeforeRead) {
future = nettyRequest.readInto(writeChannel, callback);
}
final HttpContent httpContent = httpContents.get(addedCount);
bytesToVerify += (httpContent.content().readableBytes());
suspended = bytesToVerify >= NettyRequest.bufferWatermark;
addedCount++;
nettyRequest.addContent(httpContent);
int expectedRefCountOnAdd = future == null || httpContent.content().nioBufferCount() > 0 ? 2 : 1;
assertEquals("Reference count is not as expected", expectedRefCountOnAdd, httpContent.refCnt());
}
}
if (future == null) {
future = nettyRequest.readInto(writeChannel, callback);
}
readAndVerify(bytesToVerify, writeChannel, content);
verifyRefCnts(httpContents);
writeChannel.close();
callback.awaitCallback();
if (callback.exception != null) {
throw callback.exception;
}
long futureBytesRead = future.get(1, TimeUnit.SECONDS);
assertEquals("Total bytes read does not match (callback)", content.limit(), callback.bytesRead);
assertEquals("Total bytes read does not match (future)", content.limit(), futureBytesRead);
// check twice to make sure the same digest is returned every time
for (int i = 0; i < 2; i++) {
assertArrayEquals("Part by part digest should match digest of whole", wholeDigest, nettyRequest.getDigest());
}
closeRequestAndValidate(nettyRequest, channel);
}
/**
* Generates random content and fills it up (in parts) in {@code httpContents}.
* @param httpContents the {@link List<HttpContent>} that will contain all the content in parts.
* @return the whole content as a {@link ByteBuffer} - serves as a source of truth.
*/
private ByteBuffer generateContent(List<HttpContent> httpContents) {
return generateContent(httpContents, false);
}
/**
* Generates random content and fills it up (in parts) in {@code httpContents}.
* @param httpContents the {@link List<HttpContent>} that will contain all the content in parts.
* @param useCopyForcingByteBuf if {@code true}, uses {@link CopyForcingByteBuf} instead of the default
* {@link ByteBuf}.
* @return the whole content as a {@link ByteBuffer} - serves as a source of truth.
*/
private ByteBuffer generateContent(List<HttpContent> httpContents, boolean useCopyForcingByteBuf) {
byte[] contentBytes = TestUtils.getRandomBytes(GENERATED_CONTENT_SIZE);
splitContent(contentBytes, GENERATED_CONTENT_PART_COUNT, httpContents, useCopyForcingByteBuf);
return ByteBuffer.wrap(contentBytes);
}
/**
* Splits the given {@code contentBytes} into {@code numChunks} chunks and stores them in {@code httpContents}.
* @param contentBytes the content that needs to be split.
* @param numChunks the number of chunks to split {@code contentBytes} into.
* @param httpContents the {@link List<HttpContent>} that will contain all the content in parts.
* @param useCopyForcingByteBuf if {@code true}, uses {@link CopyForcingByteBuf} instead of the default
* {@link ByteBuf}.
*/
private void splitContent(byte[] contentBytes, int numChunks, List<HttpContent> httpContents,
boolean useCopyForcingByteBuf) {
int individualPartSize = contentBytes.length / numChunks;
ByteBuf content;
for (int addedContentCount = 0; addedContentCount < numChunks - 1; addedContentCount++) {
if (useCopyForcingByteBuf) {
content =
CopyForcingByteBuf.wrappedBuffer(contentBytes, addedContentCount * individualPartSize, individualPartSize);
} else {
content = Unpooled.wrappedBuffer(contentBytes, addedContentCount * individualPartSize, individualPartSize);
}
httpContents.add(new DefaultHttpContent(content));
}
if (useCopyForcingByteBuf) {
content =
CopyForcingByteBuf.wrappedBuffer(contentBytes, (numChunks - 1) * individualPartSize, individualPartSize);
} else {
content = Unpooled.wrappedBuffer(contentBytes, (numChunks - 1) * individualPartSize, individualPartSize);
}
httpContents.add(new DefaultLastHttpContent(content));
}
/**
* Generates random content and fills it up in {@code httpContents} with a backing {@link CompositeByteBuf}.
* @param httpContents the {@link List<HttpContent>} that will contain all the content.
* @return the whole content as a {@link ByteBuffer} - serves as a source of truth.
*/
private ByteBuffer generateCompositeContent(List<HttpContent> httpContents) {
int individualPartSize = GENERATED_CONTENT_SIZE / GENERATED_CONTENT_PART_COUNT;
byte[] contentBytes = TestUtils.getRandomBytes(GENERATED_CONTENT_SIZE);
ArrayList<ByteBuf> byteBufs = new ArrayList<>(GENERATED_CONTENT_PART_COUNT);
for (int addedContentCount = 0; addedContentCount < GENERATED_CONTENT_PART_COUNT; addedContentCount++) {
byteBufs.add(Unpooled.wrappedBuffer(contentBytes, addedContentCount * individualPartSize, individualPartSize));
}
httpContents.add(new DefaultLastHttpContent(new CompositeByteBuf(ByteBufAllocator.DEFAULT, false, 20, byteBufs)));
return ByteBuffer.wrap(contentBytes);
}
/**
* Verifies that the reference counts of {@code httpContents} is undisturbed after all operations.
* @param httpContents the {@link List<HttpContent>} of contents whose reference counts need to checked.
*/
private void verifyRefCnts(List<HttpContent> httpContents) {
for (HttpContent httpContent : httpContents) {
assertEquals("Reference count of http content has changed", 1, httpContent.refCnt());
}
}
/**
* Reads from the provided {@code writeChannel} and verifies that the bytes received match the original content
* provided through {@code content}.
* @param readLengthDesired desired length of bytes to read.
* @param writeChannel the {@link ByteBufferAsyncWritableChannel} to read from.
* @param content the original content that serves as the source of truth.
* @return the number of chunks read.
* @throws InterruptedException
*/
private int readAndVerify(int readLengthDesired, ByteBufferAsyncWritableChannel writeChannel, ByteBuffer content)
throws InterruptedException {
int bytesRead = 0;
int chunksRead = 0;
while (bytesRead < readLengthDesired) {
ByteBuffer recvdContent = writeChannel.getNextChunk();
while (recvdContent.hasRemaining()) {
assertEquals("Unexpected byte", content.get(), recvdContent.get());
bytesRead++;
}
writeChannel.resolveOldestChunk(null);
chunksRead++;
}
return chunksRead;
}
// headerAndContentSizeMismatchTest() helpers
/**
* Tests reaction of NettyRequest when content size is less than the size specified in the headers.
* @throws Exception
*/
private void sizeInHeaderMoreThanContentTest() throws Exception {
List<HttpContent> httpContents = new ArrayList<HttpContent>();
ByteBuffer content = generateContent(httpContents);
HttpHeaders httpHeaders = new DefaultHttpHeaders();
httpHeaders.set(HttpHeaderNames.CONTENT_LENGTH, content.limit() + 1);
doHeaderAndContentSizeMismatchTest(httpHeaders, httpContents);
}
/**
* Tests reaction of NettyRequest when content size is more than the size specified in the headers.
* @throws Exception
*/
private void sizeInHeaderLessThanContentTest() throws Exception {
List<HttpContent> httpContents = new ArrayList<HttpContent>();
ByteBuffer content = generateContent(httpContents);
HttpHeaders httpHeaders = new DefaultHttpHeaders();
int lastHttpContentSize = httpContents.get(httpContents.size() - 1).content().readableBytes();
httpHeaders.set(HttpHeaderNames.CONTENT_LENGTH, content.limit() - lastHttpContentSize - 1);
doHeaderAndContentSizeMismatchTest(httpHeaders, httpContents);
}
/**
* Tests reaction of NettyRequest when content size is different from the size specified in the headers.
* @param httpHeaders {@link HttpHeaders} that need to be a part of the request.
* @param httpContents the {@link List<HttpContent>} that needs to be added to {@code nettyRequest}.
* @throws Exception
*/
private void doHeaderAndContentSizeMismatchTest(HttpHeaders httpHeaders, List<HttpContent> httpContents)
throws Exception {
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", httpHeaders, channel);
AsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
ReadIntoCallback callback = new ReadIntoCallback();
Future<Long> future = nettyRequest.readInto(writeChannel, callback);
int bytesAdded = 0;
HttpContent httpContentToAdd = null;
for (HttpContent httpContent : httpContents) {
httpContentToAdd = httpContent;
int contentBytes = httpContentToAdd.content().readableBytes();
if (!(httpContentToAdd instanceof LastHttpContent) && (bytesAdded + contentBytes <= nettyRequest.getSize())) {
nettyRequest.addContent(httpContentToAdd);
assertEquals("Reference count is not as expected", 2, httpContentToAdd.refCnt());
bytesAdded += contentBytes;
} else {
break;
}
}
// the addition of the next content should throw an exception.
try {
nettyRequest.addContent(httpContentToAdd);
fail("Adding content should have failed because there was a mismatch in size");
} catch (RestServiceException e) {
assertEquals("Unexpected RestServiceErrorCode", RestServiceErrorCode.BadRequest, e.getErrorCode());
}
closeRequestAndValidate(nettyRequest, channel);
writeChannel.close();
verifyRefCnts(httpContents);
callback.awaitCallback();
assertNotNull("There should be a RestServiceException in the callback", callback.exception);
assertEquals("Unexpected RestServiceErrorCode", RestServiceErrorCode.BadRequest,
((RestServiceException) callback.exception).getErrorCode());
try {
future.get();
fail("Should have thrown exception because the future is expected to have been given one");
} catch (ExecutionException e) {
RestServiceException restServiceException = (RestServiceException) Utils.getRootCause(e);
assertNotNull("There should be a RestServiceException in the future", restServiceException);
assertEquals("Unexpected RestServiceErrorCode", RestServiceErrorCode.BadRequest,
restServiceException.getErrorCode());
}
}
// digestIncorrectUsageTest() helpers.
/**
* Tests for failure when {@link NettyRequest#setDigestAlgorithm(String)} after
* {@link NettyRequest#readInto(AsyncWritableChannel, Callback)} is called.
* @throws NoSuchAlgorithmException
* @throws RestServiceException
*/
private void setDigestAfterReadTest() throws NoSuchAlgorithmException, RestServiceException {
List<HttpContent> httpContents = new ArrayList<HttpContent>();
generateContent(httpContents);
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
ByteBufferAsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
ReadIntoCallback callback = new ReadIntoCallback();
nettyRequest.readInto(writeChannel, callback);
try {
nettyRequest.setDigestAlgorithm("MD5");
fail("Setting a digest algorithm should have failed because readInto() has already been called");
} catch (IllegalStateException e) {
// expected. Nothing to do.
}
writeChannel.close();
closeRequestAndValidate(nettyRequest, channel);
}
/**
* Tests for failure when {@link NettyRequest#setDigestAlgorithm(String)} is called with an unrecognized algorithm.
* @throws RestServiceException
*/
private void setBadAlgorithmTest() throws RestServiceException {
List<HttpContent> httpContents = new ArrayList<HttpContent>();
generateContent(httpContents);
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
try {
nettyRequest.setDigestAlgorithm("NonExistentAlgorithm");
fail("Setting a digest algorithm should have failed because the algorithm isn't valid");
} catch (NoSuchAlgorithmException e) {
// expected. Nothing to do.
}
closeRequestAndValidate(nettyRequest, channel);
}
/**
* Tests for failure when {@link NettyRequest#getDigest()} is called without a call to
* {@link NettyRequest#setDigestAlgorithm(String)}.
* @throws RestServiceException
*/
private void getDigestWithoutSettingAlgorithmTest() throws RestServiceException {
List<HttpContent> httpContents = new ArrayList<HttpContent>();
generateContent(httpContents);
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
ByteBufferAsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
ReadIntoCallback callback = new ReadIntoCallback();
nettyRequest.readInto(writeChannel, callback);
for (HttpContent httpContent : httpContents) {
nettyRequest.addContent(httpContent);
}
assertNull("Digest should be null because no digest algorithm was set", nettyRequest.getDigest());
closeRequestAndValidate(nettyRequest, channel);
}
/**
* Tests for failure when {@link NettyRequest#getDigest()} is called before
* 1. All content is added.
* 2. All content is processed (i.e. before a call to {@link NettyRequest#readInto(AsyncWritableChannel, Callback)}).
* @throws NoSuchAlgorithmException
* @throws RestServiceException
*/
private void getDigestBeforeAllContentProcessedTest() throws NoSuchAlgorithmException, RestServiceException {
// all content not added test.
List<HttpContent> httpContents = new ArrayList<HttpContent>();
generateContent(httpContents);
Channel channel = new MockChannel();
NettyRequest nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
nettyRequest.setDigestAlgorithm("MD5");
// add all except the LastHttpContent
for (int i = 0; i < httpContents.size() - 1; i++) {
nettyRequest.addContent(httpContents.get(i));
}
ByteBufferAsyncWritableChannel writeChannel = new ByteBufferAsyncWritableChannel();
ReadIntoCallback callback = new ReadIntoCallback();
nettyRequest.readInto(writeChannel, callback);
try {
nettyRequest.getDigest();
fail("Getting a digest should have failed because all the content has not been added");
} catch (IllegalStateException e) {
// expected. Nothing to do.
}
closeRequestAndValidate(nettyRequest, channel);
// content not processed test.
httpContents.clear();
generateContent(httpContents);
nettyRequest = createNettyRequest(HttpMethod.POST, "/", null, channel);
nettyRequest.setDigestAlgorithm("MD5");
for (HttpContent httpContent : httpContents) {
nettyRequest.addContent(httpContent);
}
try {
nettyRequest.getDigest();
fail("Getting a digest should have failed because the content has not been processed (readInto() not called)");
} catch (IllegalStateException e) {
// expected. Nothing to do.
}
closeRequestAndValidate(nettyRequest, channel);
}
}
/**
* Callback for all read operations on {@link NettyRequest}.
*/
class ReadIntoCallback implements Callback<Long> {
public volatile long bytesRead;
public volatile Exception exception;
private final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
private final CountDownLatch latch = new CountDownLatch(1);
@Override
public void onCompletion(Long result, Exception exception) {
if (callbackInvoked.compareAndSet(false, true)) {
bytesRead = result;
this.exception = exception;
latch.countDown();
} else {
this.exception = new IllegalStateException("Callback invoked more than once");
}
}
/**
* Waits for the callback to be received.
* @throws InterruptedException if there was any intteruption while waiting for the callback.
* @throws TimeoutException if the callback did not arrive within a particular timeout.
*/
public void awaitCallback() throws InterruptedException, TimeoutException {
if (!latch.await(1, TimeUnit.SECONDS)) {
throw new TimeoutException("Timed out waiting for callback to trigger");
}
}
}
/**
* Used to test for {@link NettyRequest} behavior when a {@link AsyncWritableChannel} throws exceptions.
*/
class BadAsyncWritableChannel implements AsyncWritableChannel {
private final Exception exceptionToThrow;
private final AtomicBoolean isOpen = new AtomicBoolean(true);
/**
* Creates an instance of BadAsyncWritableChannel that throws {@code exceptionToThrow} on write.
* @param exceptionToThrow the {@link Exception} to throw on write.
*/
public BadAsyncWritableChannel(Exception exceptionToThrow) {
this.exceptionToThrow = exceptionToThrow;
}
@Override
public Future<Long> write(ByteBuffer src, Callback<Long> callback) {
if (exceptionToThrow instanceof RuntimeException) {
throw (RuntimeException) exceptionToThrow;
} else {
return markFutureInvokeCallback(callback, 0, exceptionToThrow);
}
}
@Override
public boolean isOpen() {
return isOpen.get();
}
@Override
public void close() throws IOException {
isOpen.set(false);
}
/**
* Creates and marks a future as done and invoked the callback with paramaters {@code totalBytesWritten} and
* {@code Exception}.
* @param callback the {@link Callback} to invoke.
* @param totalBytesWritten the number of bytes successfully written.
* @param exception the {@link Exception} that occurred if any.
* @return the {@link Future} that will contain the result of the operation.
*/
private Future<Long> markFutureInvokeCallback(Callback<Long> callback, long totalBytesWritten, Exception exception) {
FutureResult<Long> futureResult = new FutureResult<Long>();
futureResult.done(totalBytesWritten, exception);
if (callback != null) {
callback.onCompletion(totalBytesWritten, exception);
}
return futureResult;
}
}
/**
* A mock channel that can be configured to run custom code on {@link Channel#read()}.
*/
class MockChannel extends EmbeddedChannel {
/**
* Interface to provide code that is to be executed on {@link Channel#read()}.
*/
public interface ChannelReadCallback {
/**
* This is called when {@link Channel#read()} is called. Should contain logic that needs to be executed on read.
*/
public void onRead();
}
private final ChannelConfig config = new DefaultChannelConfig(this);
private SSLEngine sslEngine = null;
private ChannelReadCallback channelReadCallback = null;
private int queuedOnReads = 0;
MockChannel() {
// sending a placeholder handler to avoid NPE. This is of no consequence and saves us from having to implement
// needless functions.
super(new ConnectionStatsHandler(new NettyMetrics(new MetricRegistry())));
}
/**
* Add an {@link SslHandler} to the pipeline (for testing {@link NettyRequest#getSSLSession()}.
* @throws SSLException
* @throws CertificateException
*/
MockChannel addSslHandlerToPipeline() throws SSLException, CertificateException {
if (pipeline().get(SslHandler.class) == null) {
SelfSignedCertificate ssc = new SelfSignedCertificate();
SslContext sslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build();
sslEngine = sslCtx.newEngine(alloc());
pipeline().addFirst(new SslHandler(sslEngine));
}
return this;
}
/**
* @return the {@link SSLEngine} associated with this channel, or {@code null} if no {@link SslHandler} is on this
* pipeline.
*/
SSLEngine getSSLEngine() {
return sslEngine;
}
/**
* Sets the {@link ChannelReadCallback}.
* @param channelReadCallback the {@link ChannelReadCallback} that will executed on {@link #read()}.
*/
void setChannelReadCallback(ChannelReadCallback channelReadCallback) {
this.channelReadCallback = channelReadCallback;
for (; queuedOnReads > 0; queuedOnReads--) {
channelReadCallback.onRead();
}
}
@Override
public ChannelConfig config() {
return config;
}
@Override
public Channel read() {
if (channelReadCallback != null) {
channelReadCallback.onRead();
} else {
queuedOnReads++;
}
return this;
}
}
/**
* An implementation of {@link ByteBuf} that forces {@link NettyRequest} to make a copy of the data.
*/
class CopyForcingByteBuf extends UnpooledHeapByteBuf {
private static final ByteBufAllocator ALLOC = UnpooledByteBufAllocator.DEFAULT;
/**
* Returns a {@link ByteBuf} that will not expose the underlying buffer through {@link #nioBuffer()} if {@code length}
* is greater than 0.
* @param array the backing byte array.
* @param offset the offset in the array from which the data is valid.
* @param length the length of data in the array from the {@code offset} which is valid.
* @return a {@link ByteBuf} that will not expose the underlying buffer through {@link #nioBuffer()} if {@code length}
* is greater than 0.
*/
public static ByteBuf wrappedBuffer(byte[] array, int offset, int length) {
if (length == 0) {
return Unpooled.EMPTY_BUFFER;
}
return new CopyForcingByteBuf(ALLOC, array, array.length).slice(offset, length);
}
protected CopyForcingByteBuf(ByteBufAllocator alloc, byte[] initialArray, int maxCapacity) {
super(alloc, initialArray, maxCapacity);
}
@Override
public int nioBufferCount() {
return -1;
}
}