/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package io.netty.handler.codec.http2;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.codec.http2.Http2TestUtil.FrameCountDown;
import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable;
import io.netty.util.AsciiString;
import io.netty.util.concurrent.Future;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.ByteArrayOutputStream;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR;
import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel;
import static io.netty.util.CharsetUtil.UTF_8;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyShort;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests the full HTTP/2 framing stack including the connection and preface handlers.
*/
public class Http2ConnectionRoundtripTest {
// Upper bound, in seconds, used by the tests when awaiting a latch.
private static final long DEFAULT_AWAIT_TIMEOUT_SECONDS = 15;
// Mock installed as the frame listener of the client-side connection handler.
@Mock
private Http2FrameListener clientListener;
// Mock that is wrapped in a FrameCountDown and installed on the server-side connection handler.
@Mock
private Http2FrameListener serverListener;
// Connection handlers of both endpoints; assigned by bootstrapEnv.
private Http2ConnectionHandler http2Client;
private Http2ConnectionHandler http2Server;
// Bootstraps created by bootstrapEnv; teardown shuts down their event loop groups.
private ServerBootstrap sb;
private Bootstrap cb;
// The bound server (listening) channel.
private Channel serverChannel;
// The accepted server-side child channel; written from the child initializer, hence volatile.
private volatile Channel serverConnectedChannel;
// The connected client channel.
private Channel clientChannel;
// Wrapper around serverListener that is handed the latches below when it is created in bootstrapEnv.
private FrameCountDown serverFrameCountDown;
// Latches (re)created by every bootstrapEnv call with the counts passed by the test.
private CountDownLatch requestLatch;
private CountDownLatch serverSettingsAckLatch;
private CountDownLatch dataLatch;
private CountDownLatch trailersLatch;
// Either a dedicated latch or, when no positive go-away count is requested, an alias of requestLatch.
private CountDownLatch goAwayLatch;
/**
 * Creates the Mockito mocks declared on this class and installs the default
 * {@code onDataRead} stub on the client listener first, then on the server listener.
 */
@Before
public void setup() throws Exception {
    MockitoAnnotations.initMocks(this);
    final Http2FrameListener[] listeners = {clientListener, serverListener};
    for (Http2FrameListener listener : listeners) {
        mockFlowControl(listener);
    }
}
/**
 * Closes every channel opened by the test and shuts down the event loop groups.
 *
 * <p>The method tolerates a partially initialised environment: if {@code bootstrapEnv} was never
 * called, or failed before the bootstraps/groups were set up, the corresponding steps are skipped
 * instead of raising a {@link NullPointerException}. The event loop groups are shut down even when
 * closing one of the channels fails.
 */
@After
public void teardown() throws Exception {
    try {
        if (clientChannel != null) {
            clientChannel.close().syncUninterruptibly();
            clientChannel = null;
        }
        if (serverChannel != null) {
            serverChannel.close().syncUninterruptibly();
            serverChannel = null;
        }
        // Read the volatile field once so the null check and the close act on the same channel.
        final Channel serverConnectedChannel = this.serverConnectedChannel;
        if (serverConnectedChannel != null) {
            serverConnectedChannel.close().syncUninterruptibly();
            this.serverConnectedChannel = null;
        }
    } finally {
        // Start all shutdowns first, then wait, so that the groups terminate concurrently.
        Future<?> serverGroup = null;
        Future<?> serverChildGroup = null;
        Future<?> clientGroup = null;
        if (sb != null) {
            if (sb.config().group() != null) {
                serverGroup = sb.config().group().shutdownGracefully(0, 5, SECONDS);
            }
            if (sb.config().childGroup() != null) {
                serverChildGroup = sb.config().childGroup().shutdownGracefully(0, 5, SECONDS);
            }
        }
        if (cb != null && cb.config().group() != null) {
            clientGroup = cb.config().group().shutdownGracefully(0, 5, SECONDS);
        }
        if (serverGroup != null) {
            serverGroup.syncUninterruptibly();
        }
        if (serverChildGroup != null) {
            serverChildGroup.syncUninterruptibly();
        }
        if (clientGroup != null) {
            clientGroup.syncUninterruptibly();
        }
    }
}
/**
 * Opens stream 3 and resets it immediately, then opens stream 5 and waits until the client
 * receives HEADERS for stream 5. The server replies to every HEADERS frame it reads with
 * HEADERS on the same stream, so a reply for the already reset stream 3 may be in flight.
 */
@Test
public void inflightFrameAfterStreamResetShouldNotMakeConnectionUnusable() throws Exception {
    bootstrapEnv(1, 1, 2, 1);
    final CountDownLatch stream5HeadersAtClient = new CountDownLatch(1);

    // Server side: answer each HEADERS frame with a HEADERS frame on the same stream.
    final Answer<Void> replyWithHeaders = new Answer<Void>() {
        @Override
        public Void answer(InvocationOnMock invocation) throws Throwable {
            final ChannelHandlerContext serverContext = invocation.getArgument(0);
            final int streamId = (Integer) invocation.getArgument(1);
            final Http2Headers received = (Http2Headers) invocation.getArgument(2);
            http2Server.encoder().writeHeaders(serverContext, streamId, received, 0, false,
                    serverContext.newPromise());
            http2Server.flush(serverContext);
            return null;
        }
    };
    doAnswer(replyWithHeaders).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class),
            anyInt(), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean());

    // Client side: release the latch once HEADERS for stream 5 arrive.
    final Answer<Void> signalStream5 = new Answer<Void>() {
        @Override
        public Void answer(InvocationOnMock invocation) throws Throwable {
            stream5HeadersAtClient.countDown();
            return null;
        }
    };
    doAnswer(signalStream5).when(clientListener).onHeadersRead(any(ChannelHandlerContext.class),
            eq(5), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean());

    final short weight = 16;
    final Http2Headers headers = dummyHeaders();

    // First task: open stream 3 with HEADERS and reset it right away.
    final Http2Runnable openAndResetStream3 = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            final ChannelHandlerContext clientContext = ctx();
            http2Client.encoder().writeHeaders(clientContext, 3, headers, 0, weight, false, 0, false,
                    newPromise());
            http2Client.flush(clientContext);
            http2Client.encoder().writeRstStream(clientContext, 3, Http2Error.INTERNAL_ERROR.code(),
                    newPromise());
            http2Client.flush(clientContext);
        }
    };
    // Second task: open stream 5 on the same connection.
    final Http2Runnable openStream5 = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            final ChannelHandlerContext clientContext = ctx();
            http2Client.encoder().writeHeaders(clientContext, 5, headers, 0, weight, false, 0, false,
                    newPromise());
            http2Client.flush(clientContext);
        }
    };
    runInChannel(clientChannel, openAndResetStream3);
    runInChannel(clientChannel, openStream5);

    assertTrue(stream5HeadersAtClient.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
}
/**
 * Sends a HEADERS frame with the end-of-stream flag set on stream 3 and verifies that the server
 * reads it and that, after a grace period, neither endpoint has observed a GOAWAY or RST_STREAM
 * frame. The client listener is checked as well, because an error frame emitted by the server
 * would be delivered to the client rather than to the server listener.
 */
@Test
public void headersWithEndStreamShouldNotSendError() throws Exception {
    bootstrapEnv(1, 1, 2, 1);
    // Create a single stream by sending a HEADERS frame to the server.
    final short weight = 16;
    final Http2Headers headers = dummyHeaders();
    runInChannel(clientChannel, new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, weight, false, 0, true,
                    newPromise());
            http2Client.flush(ctx());
        }
    });
    assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers),
            eq(0), eq(weight), eq(false), eq(0), eq(true));
    // Wait for some time to see if a go_away or reset frame will be received.
    Thread.sleep(1000);
    // Verify that no errors have been received by either endpoint.
    verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(),
            anyLong(), any(ByteBuf.class));
    verify(serverListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(),
            anyLong());
    verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(),
            anyLong(), any(ByteBuf.class));
    verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(),
            anyLong());
    // The server will not respond, and so don't wait for graceful shutdown
    http2Client.gracefulShutdownTimeoutMillis(0);
}
/**
 * Verifies that a HEADERS write which fails on the client after the server sent SETTINGS with a
 * max header list size of 100 does not make the connection unusable: once the server sends
 * SETTINGS with {@code Http2CodecUtil.MAX_HEADER_LIST_SIZE}, HEADERS on stream 5 are written
 * successfully and read by the server, and no DATA, GOAWAY or RST_STREAM frame is observed.
 */
@Test
public void encodeViolatesMaxHeaderListSizeCanStillUseConnection() throws Exception {
bootstrapEnv(1, 2, 1, 0, 0);
// NOTE(review): the latch counts below presumably account for the SETTINGS exchange of the
// initial connection handshake in addition to the SETTINGS frames sent by this test - confirm.
final CountDownLatch serverSettingsAckLatch1 = new CountDownLatch(2);
final CountDownLatch serverSettingsAckLatch2 = new CountDownLatch(3);
final CountDownLatch clientSettingsLatch1 = new CountDownLatch(3);
final CountDownLatch serverRevHeadersLatch = new CountDownLatch(1);
final CountDownLatch clientHeadersLatch = new CountDownLatch(1);
final CountDownLatch clientDataWrite = new CountDownLatch(1);
// Causes (or null on success) captured from the write futures on the client.
final AtomicReference<Throwable> clientHeadersWriteException = new AtomicReference<Throwable>();
final AtomicReference<Throwable> clientHeadersWriteException2 = new AtomicReference<Throwable>();
final AtomicReference<Throwable> clientDataWriteException = new AtomicReference<Throwable>();
final Http2Headers headers = dummyHeaders();
// Every SETTINGS ACK read by the server counts down both ACK latches.
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
serverSettingsAckLatch1.countDown();
serverSettingsAckLatch2.countDown();
return null;
}
}).when(serverListener).onSettingsAckRead(any(ChannelHandlerContext.class));
// Every SETTINGS frame read by the client counts down the client settings latch.
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
clientSettingsLatch1.countDown();
return null;
}
}).when(clientListener).onSettingsRead(any(ChannelHandlerContext.class), any(Http2Settings.class));
// Manually add a listener for when we receive the expected headers on the server.
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
serverRevHeadersLatch.countDown();
return null;
}
}).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(5), eq(headers),
anyInt(), anyShort(), anyBoolean(), eq(0), eq(true));
// Set the maxHeaderListSize to 100 so we may be able to write some headers, but not all. We want to verify
// that we don't corrupt state if some can be written but not all.
runInChannel(serverConnectedChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Server.encoder().writeSettings(serverCtx(),
new Http2Settings().copyFrom(http2Server.decoder().localSettings())
.maxHeaderListSize(100),
serverNewPromise());
http2Server.flush(serverCtx());
}
});
// Do not write from the client before the server has read the ACK for the reduced limit.
assertTrue(serverSettingsAckLatch1.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
// The test expects this HEADERS write on stream 3 to fail; the cause is recorded for the assertion below.
http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, false, newPromise())
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
clientHeadersWriteException.set(future.cause());
}
});
// It is expected that this write should fail locally and the remote peer will never see this.
http2Client.encoder().writeData(ctx(), 3, Unpooled.buffer(), 0, true, newPromise())
.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
clientDataWriteException.set(future.cause());
clientDataWrite.countDown();
}
});
http2Client.flush(ctx());
}
});
assertTrue(clientDataWrite.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
assertNotNull("Header encode should have exceeded maxHeaderListSize!", clientHeadersWriteException.get());
assertNotNull("Data on closed stream should fail!", clientDataWriteException.get());
// Set the maxHeaderListSize to the max value so we can send the headers.
runInChannel(serverConnectedChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Server.encoder().writeSettings(serverCtx(),
new Http2Settings().copyFrom(http2Server.decoder().localSettings())
.maxHeaderListSize(Http2CodecUtil.MAX_HEADER_LIST_SIZE),
serverNewPromise());
http2Server.flush(serverCtx());
}
});
// Wait until the client has read the new SETTINGS and the server has read the matching ACK.
assertTrue(clientSettingsLatch1.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
assertTrue(serverSettingsAckLatch2.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
// Retry with a new stream (5); this time the HEADERS write is expected to succeed.
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, true,
newPromise()).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
clientHeadersWriteException2.set(future.cause());
clientHeadersLatch.countDown();
}
});
http2Client.flush(ctx());
}
});
assertTrue(clientHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
assertNull("Client write of headers should succeed with increased header list size!",
clientHeadersWriteException2.get());
assertTrue(serverRevHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
// The DATA frame written on stream 3 must never have reached the server.
verify(serverListener, never()).onDataRead(any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class),
anyInt(), anyBoolean());
// Verify that no errors have been received.
verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
verify(serverListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
}
/**
 * Verifies that data written by the client after the server sent SETTINGS with an initial window
 * size of 0 is delivered once the server sends SETTINGS with an initial window size equal to the
 * payload length, and that no GOAWAY or RST_STREAM frame is observed by either listener.
 */
@Test
public void testSettingsAckIsSentBeforeUsingFlowControl() throws Exception {
bootstrapEnv(1, 1, 1, 1);
// NOTE(review): latch2 expects one more SETTINGS ACK than latch1; both are counted down on every ACK.
final CountDownLatch serverSettingsAckLatch1 = new CountDownLatch(1);
final CountDownLatch serverSettingsAckLatch2 = new CountDownLatch(2);
final CountDownLatch serverDataLatch = new CountDownLatch(1);
final CountDownLatch clientWriteDataLatch = new CountDownLatch(1);
final byte[] data = new byte[] {1, 2, 3, 4, 5};
// Collects the bytes of every DATA frame the server reads on stream 3.
final ByteArrayOutputStream out = new ByteArrayOutputStream(data.length);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
serverSettingsAckLatch1.countDown();
serverSettingsAckLatch2.countDown();
return null;
}
}).when(serverListener).onSettingsAckRead(any(ChannelHandlerContext.class));
doAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock in) throws Throwable {
ByteBuf buf = (ByteBuf) in.getArguments()[2];
int padding = (Integer) in.getArguments()[3];
// Compute the return value before draining the buffer, because reading changes readableBytes().
int processedBytes = buf.readableBytes() + padding;
buf.readBytes(out, buf.readableBytes());
serverDataLatch.countDown();
return processedBytes;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
any(ByteBuf.class), eq(0), anyBoolean());
final Http2Headers headers = dummyHeaders();
// The server initially sends SETTINGS that set the initial flow control window size to 0.
runInChannel(serverConnectedChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Server.encoder().writeSettings(serverCtx(),
new Http2Settings().copyFrom(http2Server.decoder().localSettings())
.initialWindowSize(0),
serverNewPromise());
http2Server.flush(serverCtx());
}
});
assertTrue(serverSettingsAckLatch1.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
// The client should now attempt to send data, but the window size is 0 so it will be queued in the flow
// controller.
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false,
newPromise());
http2Client.encoder().writeData(ctx(), 3, Unpooled.wrappedBuffer(data), 0, true, newPromise());
http2Client.flush(ctx());
// Signals only that the writes were issued, not that they completed.
clientWriteDataLatch.countDown();
}
});
assertTrue(clientWriteDataLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
// Now the server raises the initial window size to data.length to allow the client to send the pending data.
runInChannel(serverConnectedChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
http2Server.encoder().writeSettings(serverCtx(),
new Http2Settings().copyFrom(http2Server.decoder().localSettings())
.initialWindowSize(data.length),
serverNewPromise());
http2Server.flush(serverCtx());
}
});
assertTrue(serverSettingsAckLatch2.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
assertTrue(serverDataLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
assertArrayEquals(data, out.toByteArray());
// Verify that no errors have been received.
verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
verify(serverListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(),
any(ByteBuf.class));
verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong());
}
/**
 * Writes a PRIORITY frame for stream 5 (depending on stream 3) followed by HEADERS on the lower
 * stream id 3, and checks that the server reads both frames without any GOAWAY or RST_STREAM
 * being observed by either listener.
 */
@Test
public void priorityUsingHigherValuedStreamIdDoesNotPreventUsingLowerStreamId() throws Exception {
    bootstrapEnv(1, 1, 2, 0);
    final Http2Headers headers = dummyHeaders();

    final Http2Runnable priorityThenHeaders = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            final ChannelHandlerContext clientContext = ctx();
            http2Client.encoder().writePriority(clientContext, 5, 3, (short) 14, false, newPromise());
            http2Client.encoder().writeHeaders(clientContext, 3, headers, 0, (short) 16, false, 0,
                    false, newPromise());
            http2Client.flush(clientContext);
        }
    };
    runInChannel(clientChannel, priorityThenHeaders);

    assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));

    // Both frames must have been delivered to the server listener.
    verify(serverListener).onPriorityRead(
            any(ChannelHandlerContext.class), eq(5), eq(3), eq((short) 14), eq(false));
    verify(serverListener).onHeadersRead(
            any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), eq((short) 16), eq(false),
            eq(0), eq(false));

    // Neither side may have seen an error frame.
    verify(serverListener, never()).onGoAwayRead(
            any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class));
    verify(serverListener, never()).onRstStreamRead(
            any(ChannelHandlerContext.class), anyInt(), anyLong());
    verify(clientListener, never()).onGoAwayRead(
            any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class));
    verify(clientListener, never()).onRstStreamRead(
            any(ChannelHandlerContext.class), anyInt(), anyLong());
}
/**
 * Writes HEADERS on stream 5 through the encoder and then HEADERS on the lower stream id 3
 * directly through the frame writer, and checks that the server listener only ever sees the
 * HEADERS for stream 5.
 */
@Test
public void headersUsingHigherValuedStreamIdPreventsUsingLowerStreamId() throws Exception {
    bootstrapEnv(1, 1, 1, 0);
    final Http2Headers headers = dummyHeaders();

    final Http2Runnable higherThenLowerStreamId = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            final ChannelHandlerContext clientContext = ctx();
            http2Client.encoder().writeHeaders(clientContext, 5, headers, 0, (short) 16, false, 0,
                    false, newPromise());
            // Goes through the raw frame writer rather than the encoder's own writeHeaders.
            http2Client.encoder().frameWriter().writeHeaders(clientContext, 3, headers, 0, (short) 16,
                    false, 0, false, newPromise());
            http2Client.flush(clientContext);
        }
    };
    runInChannel(clientChannel, higherThenLowerStreamId);

    assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));

    verify(serverListener).onHeadersRead(
            any(ChannelHandlerContext.class), eq(5), eq(headers), eq(0), eq((short) 16), eq(false),
            eq(0), eq(false));
    verify(serverListener, never()).onHeadersRead(
            any(ChannelHandlerContext.class), eq(3), any(Http2Headers.class), anyInt(), anyShort(),
            anyBoolean(), anyInt(), anyBoolean());

    // The client should receive a RST_STREAM for stream 3, but since there is no Http2Stream
    // object for it the listener is never notified.
    verify(serverListener, never()).onGoAwayRead(
            any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class));
    verify(serverListener, never()).onRstStreamRead(
            any(ChannelHandlerContext.class), anyInt(), anyLong());
    verify(clientListener, never()).onGoAwayRead(
            any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class));
    verify(clientListener, never()).onRstStreamRead(
            any(ChannelHandlerContext.class), anyInt(), anyLong());
}
/**
 * Adds a handler to the client pipeline whose {@code handlerAdded} throws an HTTP/2 connection
 * error and checks that the client channel gets closed as a result.
 */
@Test
public void http2ExceptionInPipelineShouldCloseConnection() throws Exception {
    bootstrapEnv(1, 1, 2, 1);

    // Released as soon as the client channel's close future completes.
    final CountDownLatch clientClosed = new CountDownLatch(1);
    final ChannelFutureListener signalClosed = new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            clientClosed.countDown();
        }
    };
    clientChannel.closeFuture().addListener(signalClosed);

    // Open stream 3 with a single HEADERS frame.
    final Http2Headers headers = dummyHeaders();
    final Http2Runnable openStream3 = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false,
                    newPromise());
            http2Client.flush(ctx());
        }
    };
    runInChannel(clientChannel, openStream3);

    // The server must have seen the request before the failure is injected.
    assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));

    // Inject a handler that fails with an Http2Exception the moment it is added.
    final ChannelHandlerAdapter failingHandler = new ChannelHandlerAdapter() {
        @Override
        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
            throw Http2Exception.connectionError(PROTOCOL_ERROR, "Fake Exception");
        }
    };
    clientChannel.pipeline().addFirst(failingHandler);

    assertTrue(clientClosed.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    assertFalse(clientChannel.isOpen());
}
/**
 * Makes the server listener throw a {@link RuntimeException} when it reads the HEADERS of
 * stream 3 and checks that the client channel ends up closed.
 */
@Test
public void listenerExceptionShouldCloseConnection() throws Exception {
    final Http2Headers headers = dummyHeaders();

    // The stub has to be in place before the environment is bootstrapped.
    final RuntimeException listenerFailure = new RuntimeException("Fake Exception");
    doThrow(listenerFailure).when(serverListener).onHeadersRead(
            any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), eq((short) 16),
            eq(false), eq(0), eq(false));

    bootstrapEnv(1, 0, 1, 1);

    // Released as soon as the client channel's close future completes.
    final CountDownLatch clientClosed = new CountDownLatch(1);
    final ChannelFutureListener signalClosed = new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            clientClosed.countDown();
        }
    };
    clientChannel.closeFuture().addListener(signalClosed);

    // Open stream 3 with a single HEADERS frame; reading it triggers the stubbed exception.
    final Http2Runnable openStream3 = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false,
                    newPromise());
            http2Client.flush(ctx());
        }
    };
    runInChannel(clientChannel, openStream3);

    assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));

    assertTrue(clientClosed.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    assertFalse(clientChannel.isOpen());
}
/**
 * Adds a handler to the client pipeline whose {@code handlerAdded} throws a plain
 * {@link RuntimeException} and checks that the client channel is still open two seconds later.
 */
@Test
public void nonHttp2ExceptionInPipelineShouldNotCloseConnection() throws Exception {
    bootstrapEnv(1, 1, 2, 1);

    // Would be released if the client channel's close future completed.
    final CountDownLatch clientClosed = new CountDownLatch(1);
    final ChannelFutureListener signalClosed = new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            clientClosed.countDown();
        }
    };
    clientChannel.closeFuture().addListener(signalClosed);

    // Open stream 3 with a single HEADERS frame.
    final Http2Headers headers = dummyHeaders();
    final Http2Runnable openStream3 = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false,
                    newPromise());
            http2Client.flush(ctx());
        }
    };
    runInChannel(clientChannel, openStream3);

    // The server must have seen the request before the failure is injected.
    assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));

    // Inject a handler that fails with a non-HTTP/2 exception the moment it is added.
    final ChannelHandlerAdapter failingHandler = new ChannelHandlerAdapter() {
        @Override
        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
            throw new RuntimeException("Fake Exception");
        }
    };
    clientChannel.pipeline().addFirst(failingHandler);

    // The latch must time out: no close is expected.
    final boolean closedWithinTwoSeconds = clientClosed.await(2, SECONDS);
    assertFalse(closedWithinTwoSeconds);
    assertTrue(clientChannel.isOpen());

    // Graceful shutdown is known not to complete here, so keep its timeout at zero.
    http2Client.gracefulShutdownTimeoutMillis(0);
}
/**
 * Writes HEADERS with a stream id that has overflowed past {@link Integer#MAX_VALUE} and checks
 * that the server reads a GOAWAY with last stream id 0 and the PROTOCOL_ERROR code.
 */
@Test
public void noMoreStreamIdsShouldSendGoAway() throws Exception {
    bootstrapEnv(1, 1, 3, 1, 1);
    // The server never closes its streams, so do not wait for that on shutdown.
    http2Client.gracefulShutdownTimeoutMillis(0);

    final Http2Headers headers = dummyHeaders();

    // Open (and half-close) stream 3 first.
    final Http2Runnable openStream3 = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0,
                    true, newPromise());
            http2Client.flush(ctx());
        }
    };
    runInChannel(clientChannel, openStream3);
    assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));

    // Deliberate int overflow: the resulting id is negative.
    final int overflowedStreamId = Integer.MAX_VALUE + 1;
    final Http2Runnable useOverflowedStreamId = new Http2Runnable() {
        @Override
        public void run() throws Http2Exception {
            http2Client.encoder().writeHeaders(ctx(), overflowedStreamId, headers, 0, (short) 16,
                    false, 0, true, newPromise());
            http2Client.flush(ctx());
        }
    };
    runInChannel(clientChannel, useOverflowedStreamId);

    assertTrue(goAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    verify(serverListener).onGoAwayRead(
            any(ChannelHandlerContext.class), eq(0), eq(PROTOCOL_ERROR.code()), any(ByteBuf.class));
}
/**
 * Sends HEADERS, a single 10MB DATA write and trailers on stream 3 and verifies that the server
 * reads the headers, the trailers and exactly the bytes that were sent.
 *
 * <p>Cleanup is safe even when {@code bootstrapEnv} fails: the {@code finally} block does not
 * touch {@code http2Client} while it is still {@code null}, so the original failure is not
 * replaced by a {@link NullPointerException} and the payload buffer is always released.
 */
@Test
public void flowControlProperlyChunksLargeMessage() throws Exception {
    final Http2Headers headers = dummyHeaders();
    // Create a large message to send.
    final int length = 10485760; // 10MB
    // Create a buffer filled with random bytes.
    final ByteBuf data = randomBytes(length);
    final ByteArrayOutputStream out = new ByteArrayOutputStream(length);
    // Collect every DATA payload read on stream 3 and report all bytes as processed.
    doAnswer(new Answer<Integer>() {
        @Override
        public Integer answer(InvocationOnMock in) throws Throwable {
            ByteBuf buf = (ByteBuf) in.getArguments()[2];
            int padding = (Integer) in.getArguments()[3];
            // Must be computed before the buffer is drained below.
            int processedBytes = buf.readableBytes() + padding;
            buf.readBytes(out, buf.readableBytes());
            return processedBytes;
        }
    }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3),
            any(ByteBuf.class), eq(0), anyBoolean());
    try {
        // Initialize the data latch based on the number of bytes expected.
        bootstrapEnv(length, 1, 2, 1);
        // Create the stream and send all of the data at once.
        runInChannel(clientChannel, new Http2Runnable() {
            @Override
            public void run() throws Http2Exception {
                http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0,
                        false, newPromise());
                // A retained duplicate is written so that 'data' stays readable for the comparison.
                http2Client.encoder().writeData(ctx(), 3, data.retainedDuplicate(), 0, false, newPromise());
                // Write trailers.
                http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0,
                        true, newPromise());
                http2Client.flush(ctx());
            }
        });
        // Wait for the trailers to be received.
        assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
        assertTrue(trailersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
        // Verify that headers and trailers were received.
        verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0),
                eq((short) 16), eq(false), eq(0), eq(false));
        verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0),
                eq((short) 16), eq(false), eq(0), eq(true));
        // Verify we received all the bytes.
        assertEquals(0, dataLatch.getCount());
        out.flush();
        byte[] received = out.toByteArray();
        assertArrayEquals(data.array(), received);
    } finally {
        // http2Client is still null when bootstrapEnv failed; an exception thrown here would
        // replace the original failure and skip the release below.
        if (http2Client != null) {
            // Don't wait for server to close streams
            http2Client.gracefulShutdownTimeoutMillis(0);
        }
        data.release();
        out.close();
    }
}
/**
 * Opens 2000 streams; on each one the client writes HEADERS, a PING, a 10 byte DATA frame and
 * trailers. The test then verifies the frame counts seen by the server listener and that every
 * stream received the expected payload and every ping carried the expected message.
 *
 * <p>Cleanup is safe even when {@code bootstrapEnv} fails: the {@code finally} block does not
 * touch {@code http2Client} while it is still {@code null}, so the original failure is not
 * replaced by a {@link NullPointerException} and the buffers are always released. A stream that
 * received no DATA at all is reported through an assertion naming the stream.
 */
@Test
public void stressTest() throws Exception {
    final Http2Headers headers = dummyHeaders();
    final String pingMsg = "12345678";
    int length = 10;
    final ByteBuf data = randomBytes(length);
    final String dataAsHex = ByteBufUtil.hexDump(data);
    final ByteBuf pingData = Unpooled.copiedBuffer(pingMsg, UTF_8);
    final int numStreams = 2000;
    // Collect all the ping buffers as we receive them at the server.
    final String[] receivedPings = new String[numStreams];
    doAnswer(new Answer<Void>() {
        int nextIndex;
        @Override
        public Void answer(InvocationOnMock in) throws Throwable {
            receivedPings[nextIndex++] = ((ByteBuf) in.getArguments()[1]).toString(UTF_8);
            return null;
        }
    }).when(serverListener).onPingRead(any(ChannelHandlerContext.class), any(ByteBuf.class));
    // Collect all the data buffers as we receive them at the server.
    final StringBuilder[] receivedData = new StringBuilder[numStreams];
    doAnswer(new Answer<Integer>() {
        @Override
        public Integer answer(InvocationOnMock in) throws Throwable {
            int streamId = (Integer) in.getArguments()[1];
            ByteBuf buf = (ByteBuf) in.getArguments()[2];
            int padding = (Integer) in.getArguments()[3];
            int processedBytes = buf.readableBytes() + padding;
            // Stream ids start at 3 and advance by 2, see the write loop below.
            int streamIndex = (streamId - 3) / 2;
            StringBuilder builder = receivedData[streamIndex];
            if (builder == null) {
                builder = new StringBuilder(dataAsHex.length());
                receivedData[streamIndex] = builder;
            }
            builder.append(ByteBufUtil.hexDump(buf));
            return processedBytes;
        }
    }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(),
            any(ByteBuf.class), anyInt(), anyBoolean());
    try {
        bootstrapEnv(numStreams * length, 1, numStreams * 4, numStreams);
        runInChannel(clientChannel, new Http2Runnable() {
            @Override
            public void run() throws Http2Exception {
                int upperLimit = 3 + 2 * numStreams;
                for (int streamId = 3; streamId < upperLimit; streamId += 2) {
                    // Send a bunch of data on each stream.
                    http2Client.encoder().writeHeaders(ctx(), streamId, headers, 0, (short) 16,
                            false, 0, false, newPromise());
                    http2Client.encoder().writePing(ctx(), false, pingData.retainedSlice(),
                            newPromise());
                    http2Client.encoder().writeData(ctx(), streamId, data.retainedSlice(), 0,
                            false, newPromise());
                    // Write trailers.
                    http2Client.encoder().writeHeaders(ctx(), streamId, headers, 0, (short) 16,
                            false, 0, true, newPromise());
                    http2Client.flush(ctx());
                }
            }
        });
        // Wait for all frames to be received.
        assertTrue(serverSettingsAckLatch.await(60, SECONDS));
        assertTrue(trailersLatch.await(60, SECONDS));
        verify(serverListener, times(numStreams)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
                eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(false));
        verify(serverListener, times(numStreams)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
                eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(true));
        verify(serverListener, times(numStreams)).onPingRead(any(ChannelHandlerContext.class),
                any(ByteBuf.class));
        verify(serverListener, never()).onDataRead(any(ChannelHandlerContext.class),
                anyInt(), any(ByteBuf.class), eq(0), eq(true));
        for (int i = 0; i < receivedData.length; i++) {
            final StringBuilder builder = receivedData[i];
            // Fail with a message naming the stream rather than with a bare NullPointerException.
            assertNotNull("No DATA received for stream " + (3 + 2 * i), builder);
            assertEquals(dataAsHex, builder.toString());
        }
        for (String receivedPing : receivedPings) {
            assertEquals(pingMsg, receivedPing);
        }
    } finally {
        // http2Client is still null when bootstrapEnv failed; an exception thrown here would
        // replace the original failure and skip the releases below.
        if (http2Client != null) {
            // Don't wait for server to close streams
            http2Client.gracefulShutdownTimeoutMillis(0);
        }
        data.release();
        pingData.release();
    }
}
/**
 * Bootstraps client and server without a dedicated go-away latch; delegates to the five-argument
 * overload with a go-away count of {@code -1}.
 */
private void bootstrapEnv(int dataCountDown, int settingsAckCount,
        int requestCountDown, int trailersCountDown) throws Exception {
    final int noDedicatedGoAwayLatch = -1;
    bootstrapEnv(dataCountDown, settingsAckCount, requestCountDown, trailersCountDown,
            noDedicatedGoAwayLatch);
}
/**
 * Creates the latches, starts a local server and connects a local client to it.
 *
 * <p>The method returns once the client has fired {@link Http2ConnectionPrefaceWrittenEvent} and
 * the server child channel has been initialised; {@code http2Client} and {@code http2Server} are
 * set by then.
 *
 * @param dataCountDown initial count of {@code dataLatch}
 * @param settingsAckCount initial count of {@code serverSettingsAckLatch}
 * @param requestCountDown initial count of {@code requestLatch}
 * @param trailersCountDown initial count of {@code trailersLatch}
 * @param goAwayCountDown initial count of {@code goAwayLatch}; when not positive,
 *        {@code goAwayLatch} is the same object as {@code requestLatch}
 */
private void bootstrapEnv(int dataCountDown, int settingsAckCount,
        int requestCountDown, int trailersCountDown, int goAwayCountDown) throws Exception {
    final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1);
    requestLatch = new CountDownLatch(requestCountDown);
    serverSettingsAckLatch = new CountDownLatch(settingsAckCount);
    dataLatch = new CountDownLatch(dataCountDown);
    trailersLatch = new CountDownLatch(trailersCountDown);
    if (goAwayCountDown > 0) {
        goAwayLatch = new CountDownLatch(goAwayCountDown);
    } else {
        goAwayLatch = requestLatch;
    }

    sb = new ServerBootstrap();
    cb = new Bootstrap();

    final AtomicReference<Http2ConnectionHandler> serverHandlerRef =
            new AtomicReference<Http2ConnectionHandler>();
    final CountDownLatch serverInitLatch = new CountDownLatch(1);

    // Server child channels get a connection handler whose listener is the FrameCountDown wrapper.
    final ChannelInitializer<Channel> serverInitializer = new ChannelInitializer<Channel>() {
        @Override
        protected void initChannel(Channel ch) throws Exception {
            serverConnectedChannel = ch;
            serverFrameCountDown = new FrameCountDown(serverListener, serverSettingsAckLatch,
                    requestLatch, dataLatch, trailersLatch, goAwayLatch);
            final Http2ConnectionHandler serverHandler = new Http2ConnectionHandlerBuilder()
                    .server(true)
                    .frameListener(serverFrameCountDown)
                    .validateHeaders(false)
                    .build();
            serverHandlerRef.set(serverHandler);
            ch.pipeline().addLast(serverHandler);
            serverInitLatch.countDown();
        }
    };
    sb.group(new DefaultEventLoopGroup())
            .channel(LocalServerChannel.class)
            .childHandler(serverInitializer);

    // The client gets a connection handler plus a one-shot handler that reports the preface.
    final ChannelInitializer<Channel> clientInitializer = new ChannelInitializer<Channel>() {
        @Override
        protected void initChannel(Channel ch) throws Exception {
            final ChannelPipeline pipeline = ch.pipeline();
            pipeline.addLast(new Http2ConnectionHandlerBuilder()
                    .server(false)
                    .frameListener(clientListener)
                    .validateHeaders(false)
                    .gracefulShutdownTimeoutMillis(0)
                    .build());
            pipeline.addLast(new ChannelInboundHandlerAdapter() {
                @Override
                public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
                    if (!(evt instanceof Http2ConnectionPrefaceWrittenEvent)) {
                        return;
                    }
                    prefaceWrittenLatch.countDown();
                    ctx.pipeline().remove(this);
                }
            });
        }
    };
    cb.group(new DefaultEventLoopGroup())
            .channel(LocalChannel.class)
            .handler(clientInitializer);

    final ChannelFuture bindFuture = sb.bind(new LocalAddress("Http2ConnectionRoundtripTest"));
    serverChannel = bindFuture.sync().channel();

    final ChannelFuture connectFuture = cb.connect(serverChannel.localAddress());
    connectFuture.awaitUninterruptibly();
    assertTrue(connectFuture.isSuccess());
    clientChannel = connectFuture.channel();

    assertTrue(prefaceWrittenLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    http2Client = clientChannel.pipeline().get(Http2ConnectionHandler.class);
    assertTrue(serverInitLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS));
    http2Server = serverHandlerRef.get();
}
/** Returns the first handler context of the client channel's pipeline. */
private ChannelHandlerContext ctx() {
    final ChannelPipeline clientPipeline = clientChannel.pipeline();
    return clientPipeline.firstContext();
}
/** Returns the first handler context of the accepted server-side channel's pipeline. */
private ChannelHandlerContext serverCtx() {
    final ChannelPipeline serverPipeline = serverConnectedChannel.pipeline();
    return serverPipeline.firstContext();
}
/** Creates a new promise from the client context returned by {@code ctx()}. */
private ChannelPromise newPromise() {
    final ChannelHandlerContext clientContext = ctx();
    return clientContext.newPromise();
}
/** Creates a new promise from the server context returned by {@code serverCtx()}. */
private ChannelPromise serverNewPromise() {
    final ChannelHandlerContext serverContext = serverCtx();
    return serverContext.newPromise();
}
/**
 * Builds a GET request header block for {@code https://example.org/some/path/resource2} plus
 * one extra header whose name and value are random strings.
 */
private static Http2Headers dummyHeaders() {
    final Http2Headers headers = new DefaultHttp2Headers(false);
    headers.method(new AsciiString("GET"));
    headers.scheme(new AsciiString("https"));
    headers.authority(new AsciiString("example.org"));
    headers.path(new AsciiString("/some/path/resource2"));
    headers.add(randomString(), randomString());
    return headers;
}
/**
 * Stubs {@code onDataRead} on the given listener so that it returns the number of readable
 * bytes of the data buffer plus the padding, for any arguments.
 */
private static void mockFlowControl(Http2FrameListener listener) throws Http2Exception {
    final Answer<Integer> consumeEverything = new Answer<Integer>() {
        @Override
        public Integer answer(InvocationOnMock invocation) throws Throwable {
            final Object[] arguments = invocation.getArguments();
            final ByteBuf payload = (ByteBuf) arguments[2];
            final int padding = (Integer) arguments[3];
            return payload.readableBytes() + padding;
        }
    };
    doAnswer(consumeEverything).when(listener).onDataRead(
            any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
}
/**
 * Returns a {@link ByteBuf} that wraps a freshly allocated array of {@code length} random bytes.
 */
private static ByteBuf randomBytes(int length) {
    final byte[] payload = new byte[length];
    final Random random = new Random();
    random.nextBytes(payload);
    return Unpooled.wrappedBuffer(payload);
}
}