package test.r2.integ;
import com.linkedin.common.callback.Callback;
import com.linkedin.common.callback.FutureCallback;
import com.linkedin.common.util.None;
import com.linkedin.data.ByteString;
import com.linkedin.r2.filter.CompressionConfig;
import com.linkedin.r2.filter.FilterChains;
import com.linkedin.r2.filter.compression.ClientStreamCompressionFilter;
import com.linkedin.r2.filter.compression.streaming.StreamEncodingType;
import com.linkedin.r2.filter.compression.streaming.Bzip2Compressor;
import com.linkedin.r2.filter.compression.streaming.DeflateCompressor;
import com.linkedin.r2.filter.compression.streaming.GzipCompressor;
import com.linkedin.r2.filter.compression.streaming.SnappyCompressor;
import com.linkedin.r2.filter.compression.streaming.StreamingCompressor;
import com.linkedin.r2.filter.message.stream.StreamFilter;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.Messages;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.rest.RestStatus;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamRequestBuilder;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.stream.entitystream.EntityStream;
import com.linkedin.r2.message.stream.entitystream.EntityStreams;
import com.linkedin.r2.message.stream.entitystream.ReadHandle;
import com.linkedin.r2.message.stream.entitystream.Reader;
import com.linkedin.r2.sample.Bootstrap;
import com.linkedin.r2.transport.common.Client;
import com.linkedin.r2.transport.common.StreamRequestHandler;
import com.linkedin.r2.transport.common.TransportClientFactory;
import com.linkedin.r2.transport.common.bridge.client.TransportClientAdapter;
import com.linkedin.r2.transport.common.bridge.server.TransportDispatcher;
import com.linkedin.r2.transport.common.bridge.server.TransportDispatcherBuilder;
import com.linkedin.r2.transport.http.client.HttpClientFactory;
import com.linkedin.r2.transport.http.common.HttpProtocolVersion;
import com.linkedin.r2.transport.http.server.HttpJettyServer;
import com.linkedin.r2.transport.http.server.HttpServer;
import com.linkedin.r2.transport.http.server.HttpServerFactory;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
/**
* @author Ang Xu
*/
public class TestRequestCompression
{
private static final URI GZIP_URI = URI.create("/" + StreamEncodingType.GZIP.getHttpName());
private static final URI DEFLATE_URI = URI.create("/" + StreamEncodingType.DEFLATE.getHttpName());
private static final URI BZIP2_URI = URI.create("/" + StreamEncodingType.BZIP2.getHttpName());
private static final URI SNAPPY_URI = URI.create("/" + StreamEncodingType.SNAPPY_FRAMED.getHttpName());
private static final URI NO_COMPRESSION_URI = URI.create("/noCompression");
private static final int PORT = 11940;
private static final byte BYTE = 50;
private static final int THRESHOLD = 4096;
private static final int NUM_BYTES = 1024 * 1024 * 16;
private ExecutorService _executor = Executors.newCachedThreadPool();
private HttpServer _server;
private List<TransportClientFactory> _clientFactories = new ArrayList<TransportClientFactory>();
private List<Client> _clients = new ArrayList<Client>();
@BeforeClass
public void setup() throws IOException
{
_server = new HttpServerFactory(HttpJettyServer.ServletType.ASYNC_EVENT).createH2cServer(PORT,
getTransportDispatcher(), true);
_server.start();
}
protected TransportDispatcher getTransportDispatcher()
{
return new TransportDispatcherBuilder()
.addStreamHandler(GZIP_URI, new StreamCompressionHandler(new GzipCompressor(_executor)))
.addStreamHandler(DEFLATE_URI, new StreamCompressionHandler(new DeflateCompressor(_executor)))
.addStreamHandler(BZIP2_URI, new StreamCompressionHandler(new Bzip2Compressor(_executor)))
.addStreamHandler(SNAPPY_URI, new StreamCompressionHandler(new SnappyCompressor(_executor)))
.addStreamHandler(NO_COMPRESSION_URI, new NoCompressionHandler())
.build();
}
@DataProvider
public Object[][] requestCompressionData()
{
StreamEncodingType[] encodings =
new StreamEncodingType[]{
StreamEncodingType.GZIP,
StreamEncodingType.DEFLATE,
StreamEncodingType.SNAPPY_FRAMED,
StreamEncodingType.BZIP2,
};
String[] protocols = new String[] {
HttpProtocolVersion.HTTP_1_1.name(),
HttpProtocolVersion.HTTP_2.name(),
};
Object[][] args = new Object[encodings.length * protocols.length][2];
int cur = 0;
for (StreamEncodingType requestEncoding : encodings)
{
for (String protocol : protocols)
{
StreamFilter clientCompressionFilter =
new ClientStreamCompressionFilter(requestEncoding, new CompressionConfig(THRESHOLD), null, new CompressionConfig(THRESHOLD),
Arrays.asList(new String[]{"*"}), _executor);
TransportClientFactory factory =
new HttpClientFactory.Builder().setFilterChain(FilterChains.createStreamChain(clientCompressionFilter))
.build();
HashMap<String, String> properties = new HashMap<>();
properties.put(HttpClientFactory.HTTP_PROTOCOL_VERSION, protocol);
Client client = new TransportClientAdapter(factory.getClient(properties), true);
args[cur][0] = client;
args[cur][1] = URI.create("/" + requestEncoding.getHttpName());
cur++;
_clientFactories.add(factory);
_clients.add(client);
}
}
return args;
}
@DataProvider
public Object[][] noCompressionData()
{
StreamEncodingType[] encodings =
new StreamEncodingType[]{
StreamEncodingType.GZIP,
StreamEncodingType.DEFLATE,
StreamEncodingType.SNAPPY_FRAMED,
StreamEncodingType.BZIP2,
StreamEncodingType.IDENTITY
};
String[] protocols = new String[] {
HttpProtocolVersion.HTTP_1_1.name(),
HttpProtocolVersion.HTTP_2.name(),
};
Object[][] args = new Object[encodings.length * protocols.length][1];
int cur = 0;
for (StreamEncodingType requestEncoding : encodings)
{
for (String protocol : protocols)
{
StreamFilter clientCompressionFilter =
new ClientStreamCompressionFilter(requestEncoding, new CompressionConfig(THRESHOLD), null, new CompressionConfig(THRESHOLD),
Arrays.asList(new String[]{"*"}), _executor);
TransportClientFactory factory =
new HttpClientFactory.Builder().setFilterChain(FilterChains.createStreamChain(clientCompressionFilter))
.build();
HashMap<String, String> properties = new HashMap<>();
properties.put(HttpClientFactory.HTTP_PROTOCOL_VERSION, protocol);
Client client = new TransportClientAdapter(factory.getClient(Collections.<String, String>emptyMap()), true);
args[cur][0] = client;
//args[cur][1] = URI.create("/" + requestEncoding.getHttpName());
cur++;
_clientFactories.add(factory);
_clients.add(client);
}
}
return args;
}
@Test(dataProvider = "noCompressionData")
public void testNoCompression(Client client)
throws InterruptedException, TimeoutException, ExecutionException
{
StreamRequestBuilder builder = new StreamRequestBuilder((Bootstrap.createHttpURI(PORT, NO_COMPRESSION_URI)));
BytesWriter writer = new BytesWriter(THRESHOLD-1, BYTE);
StreamRequest request = builder.build(EntityStreams.newEntityStream(writer));
final FutureCallback<StreamResponse> callback = new FutureCallback<StreamResponse>();
client.streamRequest(request, callback);
final StreamResponse response = callback.get(60, TimeUnit.SECONDS);
Assert.assertEquals(response.getStatus(), RestStatus.OK);
}
@Test(dataProvider = "requestCompressionData")
public void testRequestCompression(Client client, URI uri)
throws InterruptedException, TimeoutException, ExecutionException
{
StreamRequestBuilder builder = new StreamRequestBuilder((Bootstrap.createHttpURI(PORT, uri)));
BytesWriter writer = new BytesWriter(NUM_BYTES, BYTE);
StreamRequest request = builder.build(EntityStreams.newEntityStream(writer));
final FutureCallback<StreamResponse> callback = new FutureCallback<StreamResponse>();
client.streamRequest(request, callback);
final StreamResponse response = callback.get(60, TimeUnit.SECONDS);
Assert.assertEquals(response.getStatus(), RestStatus.OK);
}
private static class NoCompressionHandler implements StreamRequestHandler
{
@Override
public void handleRequest(StreamRequest request, RequestContext requestContext, final Callback<StreamResponse> callback)
{
final ByteChecker reader = new ByteChecker(BYTE, THRESHOLD-1, new Callback<None>()
{
@Override
public void onError(Throwable e)
{
callback.onError(e);
}
@Override
public void onSuccess(None result)
{
RestResponse response = RestStatus.responseForStatus(RestStatus.OK, "");
callback.onSuccess(Messages.toStreamResponse(response));
}
});
request.getEntityStream().setReader(reader);
}
}
private static class StreamCompressionHandler implements StreamRequestHandler
{
private final StreamingCompressor _compressor;
public StreamCompressionHandler(StreamingCompressor compressor)
{
_compressor = compressor;
}
@Override
public void handleRequest(StreamRequest request, RequestContext requestContext, final Callback<StreamResponse> callback)
{
EntityStream uncompressedStream = _compressor.inflate(request.getEntityStream());
final ByteChecker reader = new ByteChecker(BYTE, NUM_BYTES, new Callback<None>()
{
@Override
public void onError(Throwable e)
{
callback.onError(e);
}
@Override
public void onSuccess(None result)
{
RestResponse response = RestStatus.responseForStatus(RestStatus.OK, "");
callback.onSuccess(Messages.toStreamResponse(response));
}
});
uncompressedStream.setReader(reader);
}
}
private static class ByteChecker implements Reader
{
private final byte _expectedByte;
private final int _expectedLength;
private final Callback<None> _callback;
private ReadHandle _rh;
private int _length;
private boolean _bytesCorrect;
public ByteChecker(byte b, int len, Callback<None> callback)
{
_expectedByte = b;
_expectedLength = len;
_callback = callback;
_bytesCorrect = true;
}
@Override
public void onInit(ReadHandle rh)
{
_rh = rh;
_rh.request(Integer.MAX_VALUE);
}
@Override
public void onDataAvailable(ByteString data)
{
_length += data.length();
byte [] bytes = data.copyBytes();
for (byte b : bytes)
{
if (b != _expectedByte)
{
_bytesCorrect = false;
}
}
_rh.request(1);
}
@Override
public void onDone()
{
if (_bytesCorrect && _length == _expectedLength)
{
_callback.onSuccess(None.none());
}
else
{
_callback.onError(new IllegalArgumentException("input stream is incorrect"));
}
}
@Override
public void onError(Throwable e)
{
_callback.onError(e);
}
}
}