/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package io.netty.handler.codec.http2;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.AsciiString;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import junit.framework.AssertionFailedError;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE;
/**
* Utilities for the integration tests.
*/
public final class Http2TestUtil {
/**
 * An operation that may fail with a checked {@link Http2Exception}. This is the counterpart of
 * {@link Runnable} that {@code runInChannel} accepts.
 */
interface Http2Runnable {
    /**
     * Performs the operation.
     *
     * @throws Http2Exception if the operation fails
     */
    void run() throws Http2Exception;
}
/**
 * Submits {@code runnable} to the event loop of {@code channel} and returns without waiting for it to run.
 * If the operation throws an {@link Http2Exception}, the exception is wrapped in a {@link RuntimeException}
 * and rethrown on the event loop thread.
 *
 * @param channel the channel whose event loop executes the operation
 * @param runnable the operation to execute
 */
static void runInChannel(Channel channel, final Http2Runnable runnable) {
    final Runnable task = new Runnable() {
        @Override
        public void run() {
            try {
                runnable.run();
            } catch (Http2Exception cause) {
                throw new RuntimeException(cause);
            }
        }
    };
    channel.eventLoop().execute(task);
}
/**
 * Returns a newly allocated 100-byte array filled with pseudo-random data.
 */
public static byte[] randomBytes() {
    final int defaultLength = 100;
    return randomBytes(defaultLength);
}
/**
 * Returns a newly allocated array of {@code size} bytes filled with pseudo-random data from a fresh
 * {@link Random} instance.
 *
 * @param size the length of the returned array
 * @return the random bytes
 */
public static byte[] randomBytes(int size) {
    Random source = new Random();
    byte[] bytes = new byte[size];
    source.nextBytes(bytes);
    return bytes;
}
/**
 * Returns an {@link AsciiString} built from a 100-byte array of pseudo-random data.
 */
public static AsciiString randomString() {
    byte[] content = randomBytes();
    return new AsciiString(content);
}
/**
 * Returns the given string unchanged, statically typed as a {@link CharSequence}.
 *
 * @param s the string to return
 * @return {@code s} itself (same reference)
 */
public static CharSequence of(String s) {
    CharSequence asSequence = s;
    return asSequence;
}
/**
 * Creates an {@link HpackEncoder} configured with {@code MAX_HEADER_LIST_SIZE} and
 * {@code MAX_HEADER_TABLE_SIZE}, passing {@code true} for {@code ignoreMaxHeaderListSize}.
 *
 * @return the configured encoder
 * @throws Error if the encoder rejects the default limits, which is not expected to happen
 */
public static HpackEncoder newTestEncoder() {
    HpackEncoder encoder;
    try {
        encoder = newTestEncoder(true, MAX_HEADER_LIST_SIZE, MAX_HEADER_TABLE_SIZE);
    } catch (Http2Exception e) {
        throw new Error("max size not allowed?", e);
    }
    return encoder;
}
/**
 * Creates an {@link HpackEncoder} with the given limits applied.
 *
 * @param ignoreMaxHeaderListSize forwarded to the {@link HpackEncoder} constructor; presumably disables
 *                                enforcement of {@code maxHeaderListSize} while encoding (confirm against
 *                                {@code HpackEncoder})
 * @param maxHeaderListSize value passed to {@code setMaxHeaderListSize}
 * @param maxHeaderTableSize value passed to {@code setMaxHeaderTableSize}
 * @return the configured encoder
 * @throws Http2Exception if the encoder rejects either limit
 */
public static HpackEncoder newTestEncoder(boolean ignoreMaxHeaderListSize,
        long maxHeaderListSize, long maxHeaderTableSize) throws Http2Exception {
    // Forward the flag: the no-arg constructor would discard it, so callers asking for the list-size limit
    // to be ignored would silently receive an encoder that does not honor that request.
    HpackEncoder hpackEncoder = new HpackEncoder(ignoreMaxHeaderListSize);
    // setMaxHeaderTableSize requires an output buffer (presumably for the HPACK table-size update); its
    // contents are not needed here, so a scratch buffer is used and always released.
    ByteBuf buf = Unpooled.buffer();
    try {
        hpackEncoder.setMaxHeaderTableSize(buf, maxHeaderTableSize);
        hpackEncoder.setMaxHeaderListSize(maxHeaderListSize);
    } finally {
        buf.release();
    }
    return hpackEncoder;
}
/**
 * Creates an {@link HpackDecoder} configured with {@code MAX_HEADER_LIST_SIZE} and
 * {@code MAX_HEADER_TABLE_SIZE}.
 *
 * @return the configured decoder
 * @throws Error if the decoder rejects the default limits, which is not expected to happen
 */
public static HpackDecoder newTestDecoder() {
    HpackDecoder decoder;
    try {
        decoder = newTestDecoder(MAX_HEADER_LIST_SIZE, MAX_HEADER_TABLE_SIZE);
    } catch (Http2Exception e) {
        throw new Error("max size not allowed?", e);
    }
    return decoder;
}
/**
 * Creates an {@link HpackDecoder} with the given limits.
 *
 * @param maxHeaderListSize passed to the {@link HpackDecoder} constructor
 * @param maxHeaderTableSize applied through {@code setMaxHeaderTableSize}
 * @return the configured decoder
 * @throws Http2Exception if the decoder rejects the table size
 */
public static HpackDecoder newTestDecoder(long maxHeaderListSize, long maxHeaderTableSize) throws Http2Exception {
    // NOTE(review): second constructor argument; looks like an initial capacity hint — confirm in HpackDecoder.
    final int secondCtorArg = 32;
    HpackDecoder decoder = new HpackDecoder(maxHeaderListSize, secondCtorArg);
    decoder.setMaxHeaderTableSize(maxHeaderTableSize);
    return decoder;
}
/**
 * Not instantiable: this class only hosts static helpers and nested test types.
 */
private Http2TestUtil() {
    // intentionally empty
}
/**
 * A {@link ByteToMessageDecoder} that parses inbound bytes with a {@link DefaultHttp2FrameReader}. Every frame is
 * forwarded to a delegate {@link Http2FrameListener}, and a latch is counted down after each frame is forwarded.
 * Nothing is ever added to the decoder's output list.
 *
 * <p>When an {@link Http2Connection} is supplied, streams are managed around the delegate call:
 * <ul>
 * <li>Streams referenced by DATA, HEADERS, RST_STREAM, WINDOW_UPDATE and (for the promised id)
 * PUSH_PROMISE frames are created on demand before the delegate runs.</li>
 * <li>Streams are closed after DATA/HEADERS frames that end the stream, and after RST_STREAM.</li>
 * </ul>
 */
static class FrameAdapter extends ByteToMessageDecoder {
    private final Http2Connection connection;
    private final Http2FrameListener listener;
    private final DefaultHttp2FrameReader reader;
    private final CountDownLatch latch;

    FrameAdapter(Http2FrameListener listener, CountDownLatch latch) {
        this(null, listener, latch);
    }

    FrameAdapter(Http2Connection connection, Http2FrameListener listener, CountDownLatch latch) {
        this(connection, new DefaultHttp2FrameReader(false), listener, latch);
    }

    FrameAdapter(Http2Connection connection, DefaultHttp2FrameReader reader, Http2FrameListener listener,
            CountDownLatch latch) {
        this.connection = connection;
        this.listener = listener;
        this.reader = reader;
        this.latch = latch;
    }

    /** Looks up or creates {@code streamId} on this adapter's connection; {@code null} without a connection. */
    private Http2Stream getOrCreateStream(int streamId, boolean halfClosed) throws Http2Exception {
        return getOrCreateStream(connection, streamId, halfClosed);
    }

    /**
     * Returns the stream with {@code streamId} from {@code connection}, creating it if it does not exist yet.
     *
     * @param connection the connection to consult; when {@code null}, nothing happens and {@code null} is returned
     * @param streamId the stream identifier
     * @param halfClosed forwarded to {@code createStream} when a new stream is created
     * @return the existing or newly created stream, or {@code null} if {@code connection} is {@code null}
     * @throws Http2Exception if the endpoint refuses to create the stream
     */
    public static Http2Stream getOrCreateStream(Http2Connection connection, int streamId, boolean halfClosed)
            throws Http2Exception {
        if (connection == null) {
            return null;
        }
        Http2Stream existing = connection.stream(streamId);
        if (existing != null) {
            return existing;
        }
        // Even ids are treated as server-initiated and odd ids as client-initiated, so the id is "local"
        // exactly when its parity matches this side's role.
        boolean evenId = streamId % 2 == 0;
        if (connection.isServer() == evenId) {
            return connection.local().createStream(streamId, halfClosed);
        }
        return connection.remote().createStream(streamId, halfClosed);
    }

    /** Closes {@code stream} through the overridable hook, flagging that no DATA frame triggered the close. */
    private void closeStream(Http2Stream stream) {
        closeStream(stream, false);
    }

    /**
     * Closes {@code stream} if it is non-null.
     *
     * @param stream the stream to close, possibly {@code null}
     * @param dataRead {@code true} when the close was triggered by a DATA frame that ended the stream; unused
     *                 here but available to subclasses overriding this hook
     */
    protected void closeStream(Http2Stream stream, boolean dataRead) {
        if (stream == null) {
            return;
        }
        stream.close();
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // Frames are only dispatched to the listener; nothing is emitted to {@code out}.
        reader.readFrame(ctx, in, new ForwardingListener());
    }

    /**
     * Forwards each parsed frame to the delegate listener. Stream state on the connection (if any) is
     * updated around the call, and the latch is counted down once the frame is handled.
     */
    private final class ForwardingListener implements Http2FrameListener {
        private void frameHandled() {
            latch.countDown();
        }

        @Override
        public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding,
                boolean endOfStream) throws Http2Exception {
            Http2Stream stream = getOrCreateStream(streamId, endOfStream);
            int consumed = listener.onDataRead(ctx, streamId, data, padding, endOfStream);
            if (endOfStream) {
                closeStream(stream, true);
            }
            frameHandled();
            return consumed;
        }

        @Override
        public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
                boolean endStream) throws Http2Exception {
            Http2Stream stream = getOrCreateStream(streamId, endStream);
            listener.onHeadersRead(ctx, streamId, headers, padding, endStream);
            if (endStream) {
                closeStream(stream);
            }
            frameHandled();
        }

        @Override
        public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
                int streamDependency, short weight, boolean exclusive, int padding, boolean endStream)
                throws Http2Exception {
            Http2Stream stream = getOrCreateStream(streamId, endStream);
            listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding,
                    endStream);
            if (endStream) {
                closeStream(stream);
            }
            frameHandled();
        }

        @Override
        public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight,
                boolean exclusive) throws Http2Exception {
            listener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive);
            frameHandled();
        }

        @Override
        public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode)
                throws Http2Exception {
            Http2Stream stream = getOrCreateStream(streamId, false);
            listener.onRstStreamRead(ctx, streamId, errorCode);
            closeStream(stream);
            frameHandled();
        }

        @Override
        public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception {
            listener.onSettingsAckRead(ctx);
            frameHandled();
        }

        @Override
        public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception {
            listener.onSettingsRead(ctx, settings);
            frameHandled();
        }

        @Override
        public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
            listener.onPingRead(ctx, data);
            frameHandled();
        }

        @Override
        public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
            listener.onPingAckRead(ctx, data);
            frameHandled();
        }

        @Override
        public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId,
                Http2Headers headers, int padding) throws Http2Exception {
            // Register the promised stream (not the carrying one) before the delegate sees the frame.
            getOrCreateStream(promisedStreamId, false);
            listener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding);
            frameHandled();
        }

        @Override
        public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
                throws Http2Exception {
            listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
            frameHandled();
        }

        @Override
        public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement)
                throws Http2Exception {
            getOrCreateStream(streamId, false);
            listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement);
            frameHandled();
        }

        @Override
        public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
                ByteBuf payload) throws Http2Exception {
            listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
            frameHandled();
        }
    }
}
/**
 * Decorates a {@link Http2FrameListener}: every callback is forwarded to the delegate first, then latches are
 * counted down so a test can wait for frames to arrive.
 * <ul>
 * <li>{@code settingsAckLatch}: once per SETTINGS ACK.</li>
 * <li>{@code goAwayLatch}: once per GOAWAY. The 3- and 5-argument constructors use {@code messageLatch}
 * here.</li>
 * <li>{@code messageLatch}: once per every other frame.</li>
 * <li>{@code dataLatch} (may be {@code null}): once per readable byte of each DATA frame, measured before
 * the delegate is called.</li>
 * <li>{@code trailersLatch} (may be {@code null}): once per HEADERS frame that ends the stream.</li>
 * </ul>
 */
static class FrameCountDown implements Http2FrameListener {
    private final Http2FrameListener listener;
    private final CountDownLatch messageLatch;
    private final CountDownLatch settingsAckLatch;
    private final CountDownLatch dataLatch;
    private final CountDownLatch trailersLatch;
    private final CountDownLatch goAwayLatch;

    FrameCountDown(Http2FrameListener listener, CountDownLatch settingsAckLatch, CountDownLatch messageLatch) {
        this(listener, settingsAckLatch, messageLatch, null, null);
    }

    FrameCountDown(Http2FrameListener listener, CountDownLatch settingsAckLatch, CountDownLatch messageLatch,
            CountDownLatch dataLatch, CountDownLatch trailersLatch) {
        // GOAWAY frames count as ordinary messages unless a dedicated latch is supplied.
        this(listener, settingsAckLatch, messageLatch, dataLatch, trailersLatch, messageLatch);
    }

    FrameCountDown(Http2FrameListener listener, CountDownLatch settingsAckLatch, CountDownLatch messageLatch,
            CountDownLatch dataLatch, CountDownLatch trailersLatch, CountDownLatch goAwayLatch) {
        this.listener = listener;
        this.settingsAckLatch = settingsAckLatch;
        this.messageLatch = messageLatch;
        this.dataLatch = dataLatch;
        this.trailersLatch = trailersLatch;
        this.goAwayLatch = goAwayLatch;
    }

    /** Shared bookkeeping for both HEADERS callbacks. */
    private void headersReceived(boolean endStream) {
        messageLatch.countDown();
        if (endStream && trailersLatch != null) {
            trailersLatch.countDown();
        }
    }

    @Override
    public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream)
            throws Http2Exception {
        // Capture the size first: the delegate may consume the buffer.
        final int readable = data.readableBytes();
        int processed = listener.onDataRead(ctx, streamId, data, padding, endOfStream);
        messageLatch.countDown();
        if (dataLatch != null) {
            for (int remaining = readable; remaining > 0; remaining--) {
                dataLatch.countDown();
            }
        }
        return processed;
    }

    @Override
    public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
            boolean endStream) throws Http2Exception {
        listener.onHeadersRead(ctx, streamId, headers, padding, endStream);
        headersReceived(endStream);
    }

    @Override
    public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
            short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception {
        listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream);
        headersReceived(endStream);
    }

    @Override
    public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight,
            boolean exclusive) throws Http2Exception {
        listener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive);
        messageLatch.countDown();
    }

    @Override
    public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception {
        listener.onRstStreamRead(ctx, streamId, errorCode);
        messageLatch.countDown();
    }

    @Override
    public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception {
        listener.onSettingsAckRead(ctx);
        settingsAckLatch.countDown();
    }

    @Override
    public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception {
        listener.onSettingsRead(ctx, settings);
        messageLatch.countDown();
    }

    @Override
    public void onPingRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
        listener.onPingRead(ctx, data);
        messageLatch.countDown();
    }

    @Override
    public void onPingAckRead(ChannelHandlerContext ctx, ByteBuf data) throws Http2Exception {
        listener.onPingAckRead(ctx, data);
        messageLatch.countDown();
    }

    @Override
    public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId,
            Http2Headers headers, int padding) throws Http2Exception {
        listener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding);
        messageLatch.countDown();
    }

    @Override
    public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData)
            throws Http2Exception {
        listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
        goAwayLatch.countDown();
    }

    @Override
    public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement)
            throws Http2Exception {
        listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement);
        messageLatch.countDown();
    }

    @Override
    public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
            ByteBuf payload) throws Http2Exception {
        listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
        messageLatch.countDown();
    }
}
/**
 * Creates a {@link ChannelPromise} for {@code channel} that imitates a void promise in tests. The promise:
 * <ul>
 * <li>reports {@code isVoid() == true};</li>
 * <li>rejects any listener registration with an {@link AssertionFailedError};</li>
 * <li>never records a failure itself, and instead fires it through the channel pipeline with
 * {@code fireExceptionCaught}.</li>
 * </ul>
 * {@code unvoid()} returns a regular promise whose failures are fired through the pipeline in the same way.
 * Notifications use {@link ImmediateEventExecutor#INSTANCE}.
 *
 * @param channel the channel the promise (and any unvoided promise) is bound to
 * @return the void-like promise
 */
static ChannelPromise newVoidPromise(final Channel channel) {
    return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE) {
        @Override
        public boolean isVoid() {
            return true;
        }

        @Override
        public ChannelPromise addListener(
                GenericFutureListener<? extends Future<? super Void>> listener) {
            throw new AssertionFailedError();
        }

        @Override
        public ChannelPromise addListeners(
                GenericFutureListener<? extends Future<? super Void>>... listeners) {
            throw new AssertionFailedError();
        }

        @Override
        public ChannelPromise setFailure(Throwable cause) {
            tryFailure(cause);
            return this;
        }

        @Override
        public boolean tryFailure(Throwable cause) {
            // The promise itself is left incomplete; the failure is surfaced on the pipeline instead.
            channel().pipeline().fireExceptionCaught(cause);
            return true;
        }

        @Override
        public ChannelPromise unvoid() {
            ChannelPromise regular = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE);
            regular.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (future.isSuccess()) {
                        return;
                    }
                    channel().pipeline().fireExceptionCaught(future.cause());
                }
            });
            return regular;
        }
    };
}
}