/* * Copyright 2014 The Netty Project * * The Netty Project licenses this file to you under the Apache License, version 2.0 (the * "License"); you may not use this file except in compliance with the License. You may obtain a * copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software distributed under the License * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. See the License for the specific language governing permissions and limitations under * the License. */ package io.netty.handler.codec.http2; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.EmptyByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; import io.netty.util.AsciiString; import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.util.LinkedList; import java.util.List; import static io.netty.buffer.Unpooled.EMPTY_BUFFER; import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_PADDING; import static io.netty.handler.codec.http2.Http2HeadersEncoder.NEVER_SENSITIVE; import static io.netty.handler.codec.http2.Http2TestUtil.newTestDecoder; import static io.netty.handler.codec.http2.Http2TestUtil.newTestEncoder; import static io.netty.handler.codec.http2.Http2TestUtil.randomString; import static io.netty.util.CharsetUtil.UTF_8; import static java.lang.Math.min; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.fail; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyBoolean; import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyShort; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.isA; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** * Tests encoding/decoding each HTTP2 frame type. */ public class Http2FrameRoundtripTest { private static final byte[] MESSAGE = "hello world".getBytes(UTF_8); private static final int STREAM_ID = 0x7FFFFFFF; private static final int WINDOW_UPDATE = 0x7FFFFFFF; private static final long ERROR_CODE = 0xFFFFFFFFL; @Mock private Http2FrameListener listener; @Mock private ChannelHandlerContext ctx; @Mock private EventExecutor executor; @Mock private Channel channel; @Mock private ByteBufAllocator alloc; private Http2FrameWriter writer; private Http2FrameReader reader; private final List<ByteBuf> needReleasing = new LinkedList<ByteBuf>(); @Before public void setup() throws Exception { MockitoAnnotations.initMocks(this); when(ctx.alloc()).thenReturn(alloc); when(ctx.executor()).thenReturn(executor); when(ctx.channel()).thenReturn(channel); doAnswer(new Answer<ByteBuf>() { @Override public ByteBuf answer(InvocationOnMock in) throws Throwable { return Unpooled.buffer(); } }).when(alloc).buffer(); doAnswer(new Answer<ByteBuf>() { @Override public ByteBuf answer(InvocationOnMock in) throws Throwable { return Unpooled.buffer((Integer) in.getArguments()[0]); } }).when(alloc).buffer(anyInt()); doAnswer(new Answer<ChannelPromise>() { @Override public ChannelPromise answer(InvocationOnMock invocation) throws Throwable { return new DefaultChannelPromise(channel, GlobalEventExecutor.INSTANCE); } }).when(ctx).newPromise(); writer = new DefaultHttp2FrameWriter(new DefaultHttp2HeadersEncoder(NEVER_SENSITIVE, newTestEncoder())); reader = new DefaultHttp2FrameReader(new DefaultHttp2HeadersDecoder(false, newTestDecoder())); } @After public void teardown() { try { // Release all of the buffers. for (ByteBuf buf : needReleasing) { buf.release(); } // Now verify that all of the reference counts are zero. for (ByteBuf buf : needReleasing) { int expectedFinalRefCount = 0; if (buf.isReadOnly() || buf instanceof EmptyByteBuf) { // Special case for when we're writing slices of the padding buffer. expectedFinalRefCount = 1; } assertEquals(expectedFinalRefCount, buf.refCnt()); } } finally { needReleasing.clear(); } } @Test public void emptyDataShouldMatch() throws Exception { final ByteBuf data = EMPTY_BUFFER; writer.writeData(ctx, STREAM_ID, data.slice(), 0, false, ctx.newPromise()); readFrames(); verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(0), eq(false)); } @Test public void dataShouldMatch() throws Exception { final ByteBuf data = data(10); writer.writeData(ctx, STREAM_ID, data.slice(), 1, false, ctx.newPromise()); readFrames(); verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(1), eq(false)); } @Test public void dataWithPaddingShouldMatch() throws Exception { final ByteBuf data = data(10); writer.writeData(ctx, STREAM_ID, data.slice(), MAX_PADDING, true, ctx.newPromise()); readFrames(); verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(MAX_PADDING), eq(true)); } @Test public void largeDataFrameShouldMatch() throws Exception { // Create a large message to force chunking. final ByteBuf originalData = data(1024 * 1024); final int originalPadding = 100; final boolean endOfStream = true; writer.writeData(ctx, STREAM_ID, originalData.slice(), originalPadding, endOfStream, ctx.newPromise()); readFrames(); // Verify that at least one frame was sent with eos=false and exactly one with eos=true. verify(listener, atLeastOnce()).onDataRead(eq(ctx), eq(STREAM_ID), any(ByteBuf.class), anyInt(), eq(false)); verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), any(ByteBuf.class), anyInt(), eq(true)); // Capture the read data and padding. ArgumentCaptor<ByteBuf> dataCaptor = ArgumentCaptor.forClass(ByteBuf.class); ArgumentCaptor<Integer> paddingCaptor = ArgumentCaptor.forClass(Integer.class); verify(listener, atLeastOnce()).onDataRead(eq(ctx), eq(STREAM_ID), dataCaptor.capture(), paddingCaptor.capture(), anyBoolean()); // Make sure the data matches the original. for (ByteBuf chunk : dataCaptor.getAllValues()) { ByteBuf originalChunk = originalData.readSlice(chunk.readableBytes()); assertEquals(originalChunk, chunk); } assertFalse(originalData.isReadable()); // Make sure the padding matches the original. int totalReadPadding = 0; for (int framePadding : paddingCaptor.getAllValues()) { totalReadPadding += framePadding; } assertEquals(originalPadding, totalReadPadding); } @Test public void emptyHeadersShouldMatch() throws Exception { final Http2Headers headers = EmptyHttp2Headers.INSTANCE; writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise()); readFrames(); verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true)); } @Test public void emptyHeadersWithPaddingShouldMatch() throws Exception { final Http2Headers headers = EmptyHttp2Headers.INSTANCE; writer.writeHeaders(ctx, STREAM_ID, headers, MAX_PADDING, true, ctx.newPromise()); readFrames(); verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(MAX_PADDING), eq(true)); } @Test public void binaryHeadersWithoutPriorityShouldMatch() throws Exception { final Http2Headers headers = binaryHeaders(); writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise()); readFrames(); verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true)); } @Test public void headersFrameWithoutPriorityShouldMatch() throws Exception { final Http2Headers headers = headers(); writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise()); readFrames(); verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true)); } @Test public void headersFrameWithPriorityShouldMatch() throws Exception { final Http2Headers headers = headers(); writer.writeHeaders(ctx, STREAM_ID, headers, 4, (short) 255, true, 0, true, ctx.newPromise()); readFrames(); verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(4), eq((short) 255), eq(true), eq(0), eq(true)); } @Test public void headersWithPaddingWithoutPriorityShouldMatch() throws Exception { final Http2Headers headers = headers(); writer.writeHeaders(ctx, STREAM_ID, headers, MAX_PADDING, true, ctx.newPromise()); readFrames(); verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(MAX_PADDING), eq(true)); } @Test public void headersWithPaddingWithPriorityShouldMatch() throws Exception { final Http2Headers headers = headers(); writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, 1, true, ctx.newPromise()); readFrames(); verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true), eq(1), eq(true)); } @Test public void continuedHeadersShouldMatch() throws Exception { final Http2Headers headers = largeHeaders(); writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, 0, true, ctx.newPromise()); readFrames(); verify(listener) .onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true), eq(0), eq(true)); } @Test public void continuedHeadersWithPaddingShouldMatch() throws Exception { final Http2Headers headers = largeHeaders(); writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, MAX_PADDING, true, ctx.newPromise()); readFrames(); verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true), eq(MAX_PADDING), eq(true)); } @Test public void headersThatAreTooBigShouldFail() throws Exception { reader = new DefaultHttp2FrameReader(false); final int maxListSize = 100; reader.configuration().headersConfiguration().maxHeaderListSize(maxListSize, maxListSize); final Http2Headers headers = headersOfSize(maxListSize + 1); writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, MAX_PADDING, true, ctx.newPromise()); try { readFrames(); fail(); } catch (Http2Exception e) { verify(listener, never()).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); } } @Test public void emptyPushPromiseShouldMatch() throws Exception { final Http2Headers headers = EmptyHttp2Headers.INSTANCE; writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0, ctx.newPromise()); readFrames(); verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0)); } @Test public void pushPromiseFrameShouldMatch() throws Exception { final Http2Headers headers = headers(); writer.writePushPromise(ctx, STREAM_ID, 1, headers, 5, ctx.newPromise()); readFrames(); verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(1), eq(headers), eq(5)); } @Test public void pushPromiseWithPaddingShouldMatch() throws Exception { final Http2Headers headers = headers(); writer.writePushPromise(ctx, STREAM_ID, 2, headers, MAX_PADDING, ctx.newPromise()); readFrames(); verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(MAX_PADDING)); } @Test public void continuedPushPromiseShouldMatch() throws Exception { final Http2Headers headers = largeHeaders(); writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0, ctx.newPromise()); readFrames(); verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0)); } @Test public void continuedPushPromiseWithPaddingShouldMatch() throws Exception { final Http2Headers headers = largeHeaders(); writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0xFF, ctx.newPromise()); readFrames(); verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0xFF)); } @Test public void goAwayFrameShouldMatch() throws Exception { final String text = "test"; final ByteBuf data = buf(text.getBytes()); writer.writeGoAway(ctx, STREAM_ID, ERROR_CODE, data.slice(), ctx.newPromise()); readFrames(); ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); verify(listener).onGoAwayRead(eq(ctx), eq(STREAM_ID), eq(ERROR_CODE), captor.capture()); assertEquals(data, captor.getValue()); } @Test public void pingFrameShouldMatch() throws Exception { final ByteBuf data = buf("01234567".getBytes(UTF_8)); writer.writePing(ctx, false, data.slice(), ctx.newPromise()); readFrames(); ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); verify(listener).onPingRead(eq(ctx), captor.capture()); assertEquals(data, captor.getValue()); } @Test public void pingAckFrameShouldMatch() throws Exception { final ByteBuf data = buf("01234567".getBytes(UTF_8)); writer.writePing(ctx, true, data.slice(), ctx.newPromise()); readFrames(); ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); verify(listener).onPingAckRead(eq(ctx), captor.capture()); assertEquals(data, captor.getValue()); } @Test public void priorityFrameShouldMatch() throws Exception { writer.writePriority(ctx, STREAM_ID, 1, (short) 1, true, ctx.newPromise()); readFrames(); verify(listener).onPriorityRead(eq(ctx), eq(STREAM_ID), eq(1), eq((short) 1), eq(true)); } @Test public void rstStreamFrameShouldMatch() throws Exception { writer.writeRstStream(ctx, STREAM_ID, ERROR_CODE, ctx.newPromise()); readFrames(); verify(listener).onRstStreamRead(eq(ctx), eq(STREAM_ID), eq(ERROR_CODE)); } @Test public void emptySettingsFrameShouldMatch() throws Exception { final Http2Settings settings = new Http2Settings(); writer.writeSettings(ctx, settings, ctx.newPromise()); readFrames(); verify(listener).onSettingsRead(eq(ctx), eq(settings)); } @Test public void settingsShouldStripShouldMatch() throws Exception { final Http2Settings settings = new Http2Settings(); settings.pushEnabled(true); settings.headerTableSize(4096); settings.initialWindowSize(123); settings.maxConcurrentStreams(456); writer.writeSettings(ctx, settings, ctx.newPromise()); readFrames(); verify(listener).onSettingsRead(eq(ctx), eq(settings)); } @Test public void settingsAckShouldMatch() throws Exception { writer.writeSettingsAck(ctx, ctx.newPromise()); readFrames(); verify(listener).onSettingsAckRead(eq(ctx)); } @Test public void windowUpdateFrameShouldMatch() throws Exception { writer.writeWindowUpdate(ctx, STREAM_ID, WINDOW_UPDATE, ctx.newPromise()); readFrames(); verify(listener).onWindowUpdateRead(eq(ctx), eq(STREAM_ID), eq(WINDOW_UPDATE)); } private void readFrames() throws Http2Exception { // Now read all of the written frames. ByteBuf write = captureWrites(); reader.readFrame(ctx, write, listener); } private static ByteBuf data(int size) { byte[] data = new byte[size]; for (int ix = 0; ix < data.length;) { int length = min(MESSAGE.length, data.length - ix); System.arraycopy(MESSAGE, 0, data, ix, length); ix += length; } return buf(data); } private static ByteBuf buf(byte[] bytes) { return Unpooled.wrappedBuffer(bytes); } private <T extends ByteBuf> T releaseLater(T buf) { needReleasing.add(buf); return buf; } private ByteBuf captureWrites() { ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); verify(ctx, atLeastOnce()).write(captor.capture(), isA(ChannelPromise.class)); CompositeByteBuf composite = releaseLater(Unpooled.compositeBuffer()); for (ByteBuf buf : captor.getAllValues()) { buf = releaseLater(buf.retain()); composite.addComponent(true, buf); } return composite; } private static Http2Headers headers() { return new DefaultHttp2Headers(false).method(AsciiString.of("GET")).scheme(AsciiString.of("https")) .authority(AsciiString.of("example.org")).path(AsciiString.of("/some/path/resource2")) .add(randomString(), randomString()); } private static Http2Headers largeHeaders() { DefaultHttp2Headers headers = new DefaultHttp2Headers(false); for (int i = 0; i < 100; ++i) { String key = "this-is-a-test-header-key-" + i; String value = "this-is-a-test-header-value-" + i; headers.add(AsciiString.of(key), AsciiString.of(value)); } return headers; } private static Http2Headers headersOfSize(final int minSize) { final AsciiString singleByte = new AsciiString(new byte[]{0}, false); DefaultHttp2Headers headers = new DefaultHttp2Headers(false); for (int size = 0; size < minSize; size += 2) { headers.add(singleByte, singleByte); } return headers; } private static Http2Headers binaryHeaders() { DefaultHttp2Headers headers = new DefaultHttp2Headers(false); for (int ix = 0; ix < 10; ++ix) { headers.add(randomString(), randomString()); } return headers; } }