package test.r2.integ;
import com.linkedin.common.callback.Callback;
import com.linkedin.common.util.None;
import com.linkedin.data.ByteString;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.Messages;
import com.linkedin.r2.message.rest.RestException;
import com.linkedin.r2.message.rest.RestResponse;
import com.linkedin.r2.message.rest.RestStatus;
import com.linkedin.r2.message.stream.StreamException;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamRequestBuilder;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.stream.entitystream.EntityStream;
import com.linkedin.r2.message.stream.entitystream.EntityStreams;
import com.linkedin.r2.message.stream.entitystream.ReadHandle;
import com.linkedin.r2.message.stream.entitystream.Reader;
import com.linkedin.r2.message.stream.entitystream.WriteHandle;
import com.linkedin.r2.sample.Bootstrap;
import com.linkedin.r2.transport.common.Client;
import com.linkedin.r2.transport.common.StreamRequestHandler;
import com.linkedin.r2.transport.common.bridge.server.TransportDispatcher;
import com.linkedin.r2.transport.common.bridge.server.TransportDispatcherBuilder;
import com.linkedin.r2.transport.http.client.HttpClientFactory;
import com.linkedin.r2.transport.http.server.HttpJettyServer;
import com.linkedin.r2.transport.http.server.HttpServerFactory;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Factory;
import org.testng.annotations.Test;
import java.net.URI;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
* Tests a client sending a streaming request and a server receiving that streaming request.
*
* @author Zhenkai Zhu
*/
public class TestStreamRequest extends AbstractStreamTest
{
// Dispatcher paths; each one is bound to a stream handler in getTransportDispatcher().
private static final URI LARGE_URI = URI.create("/large");
private static final URI FOOBAR_URI = URI.create("/foobar");
private static final URI RATE_LIMITED_URI = URI.create("/rated-limited");
private static final URI ERROR_RECEIVER_URI = URI.create("/error-receiver");
// Handlers kept as fields so that tests can fetch the reader they created through getReader().
private CheckRequestHandler _checkRequestHandler;
private RateLimitedRequestHandler _rateLimitedRequestHandler;
// Servlet type injected by the TestNG factory; handed to HttpServerFactory in getServerFactory().
private final HttpJettyServer.ServletType _servletType;
/**
 * Creates a test instance for one of the servlet types produced by the {@code configs}
 * data provider.
 *
 * @param servletType servlet type used when the server factory is created
 */
@Factory(dataProvider = "configs")
public TestStreamRequest(HttpJettyServer.ServletType servletType)
{
  this._servletType = servletType;
}
/**
 * Supplies the constructor arguments for the factory: one row per servlet type under test.
 *
 * @return a two-row table holding {@code RAP} and {@code ASYNC_EVENT}
 */
@DataProvider
public static Object[][] configs()
{
  final Object[][] servletTypes = {
      {HttpJettyServer.ServletType.RAP},
      {HttpJettyServer.ServletType.ASYNC_EVENT}
  };
  return servletTypes;
}
/**
 * Builds the server factory for the servlet type this instance was created with.
 *
 * @return a new {@link HttpServerFactory} configured with {@code _servletType}
 */
@Override
protected HttpServerFactory getServerFactory()
{
  final HttpServerFactory factory = new HttpServerFactory(_servletType);
  return factory;
}
/**
 * Creates the scheduler and the request handlers, then registers one stream handler per
 * test URI.
 *
 * @return the dispatcher routing the four test paths to their handlers
 */
@Override
protected TransportDispatcher getTransportDispatcher()
{
  // The scheduler must be assigned first: the rate-limited handler captures it on construction.
  _scheduler = Executors.newSingleThreadScheduledExecutor();
  _checkRequestHandler = new CheckRequestHandler(BYTE);
  _rateLimitedRequestHandler = new RateLimitedRequestHandler(_scheduler, INTERVAL, BYTE);

  final TransportDispatcher dispatcher = new TransportDispatcherBuilder()
      .addStreamHandler(LARGE_URI, _checkRequestHandler)
      .addStreamHandler(FOOBAR_URI, new CheckRequestHandler(BYTE))
      .addStreamHandler(RATE_LIMITED_URI, _rateLimitedRequestHandler)
      .addStreamHandler(ERROR_RECEIVER_URI, new ThrowWhenReceivingRequestHandler())
      .build();
  return dispatcher;
}
/**
 * Extends the inherited HTTP/1 client properties with a request timeout of "30000".
 *
 * @return the map obtained from the superclass, with the request timeout entry added
 */
@Override
protected Map<String, String> getHttp1ClientProperties()
{
  final Map<String, String> properties = super.getHttp1ClientProperties();
  properties.put(HttpClientFactory.HTTP_REQUEST_TIMEOUT, "30000");
  return properties;
}
/**
 * Extends the inherited HTTP/2 client properties with a request timeout of "30000".
 *
 * @return the map obtained from the superclass, with the request timeout entry added
 */
@Override
protected Map<String, String> getHttp2ClientProperties()
{
  final Map<String, String> properties = super.getHttp2ClientProperties();
  properties.put(HttpClientFactory.HTTP_REQUEST_TIMEOUT, "30000");
  return properties;
}
/**
 * Streams {@code LARGE_BYTES_NUM} bytes to the {@code /large} handler with every client and
 * verifies that the request succeeded and that the server-side reader saw every byte intact.
 *
 * @throws Exception if the wait for the response is interrupted
 */
@Test
public void testRequestLarge() throws Exception
{
  for (Client client : clients())
  {
    final long totalBytes = LARGE_BYTES_NUM;
    EntityStream entityStream = EntityStreams.newEntityStream(new BytesWriter(totalBytes, BYTE));
    StreamRequestBuilder builder = new StreamRequestBuilder(Bootstrap.createHttpURI(PORT, LARGE_URI));
    StreamRequest request = builder.setMethod("POST").build(entityStream);
    final AtomicInteger status = new AtomicInteger(-1);
    final CountDownLatch latch = new CountDownLatch(1);
    Callback<StreamResponse> callback = expectSuccessCallback(latch, status);
    client.streamRequest(request, callback);
    // Fail with an explicit message on timeout instead of a confusing status mismatch below.
    Assert.assertTrue(latch.await(60000, TimeUnit.MILLISECONDS), "Timed out waiting for the response");
    Assert.assertEquals(status.get(), RestStatus.OK);
    BytesReader reader = _checkRequestHandler.getReader();
    Assert.assertNotNull(reader);
    // TestNG's assertEquals takes (actual, expected).
    Assert.assertEquals(reader.getTotalBytes(), totalBytes);
    Assert.assertTrue(reader.allBytesCorrect());
  }
}
// jetty 404 tests singled out
/**
 * Posts a tiny payload to a path with no registered handler and expects the error callback
 * to report status 404. Currently disabled.
 *
 * @throws Exception if the wait for the response is interrupted
 */
@Test(enabled = false)
public void test404() throws Exception
{
  for (Client client : clients())
  {
    final long payloadSize = TINY_BYTES_NUM;
    final EntityStream payload = EntityStreams.newEntityStream(new BytesWriter(payloadSize, BYTE));
    final StreamRequest request =
        new StreamRequestBuilder(Bootstrap.createHttpURI(PORT, URI.create("/boo")))
            .setMethod("POST")
            .build(payload);

    final AtomicInteger statusCode = new AtomicInteger(-1);
    final CountDownLatch done = new CountDownLatch(1);
    client.streamRequest(request, expectErrorCallback(done, statusCode));

    done.await(60000, TimeUnit.MILLISECONDS);
    Assert.assertEquals(statusCode.get(), 404);
  }
}
/**
 * Sends a request whose entity stream is fed by an {@code ErrorWriter}, which reports an
 * error part-way through writing, and verifies that each client's callback receives an error.
 *
 * @throws Exception if the wait for the callback is interrupted
 */
@Test
public void testErrorWriter() throws Exception
{
  for (Client client : clients())
  {
    final long totalBytes = SMALL_BYTES_NUM;
    EntityStream entityStream = EntityStreams.newEntityStream(new ErrorWriter(totalBytes, BYTE));
    StreamRequestBuilder builder = new StreamRequestBuilder(Bootstrap.createHttpURI(PORT, FOOBAR_URI));
    StreamRequest request = builder.setMethod("POST").build(entityStream);
    final CountDownLatch latch = new CountDownLatch(1);
    final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
    Callback<StreamResponse> callback = new Callback<StreamResponse>()
    {
      @Override
      public void onError(Throwable e)
      {
        error.set(e);
        latch.countDown();
      }

      @Override
      public void onSuccess(StreamResponse result)
      {
        latch.countDown();
      }
    };
    client.streamRequest(request, callback);
    // Bounded wait: an unbounded await() would hang the whole suite if the callback never fires.
    Assert.assertTrue(latch.await(60000, TimeUnit.MILLISECONDS), "Timed out waiting for the callback");
    Assert.assertNotNull(error.get());
  }
}
/**
 * Sends a request to the handler whose reader throws from {@code onDataAvailable} and
 * verifies that each client's callback receives an error.
 *
 * @throws Exception if the wait for the callback is interrupted
 */
@Test
public void testErrorReceiver() throws Exception
{
  for (Client client : clients())
  {
    final long totalBytes = SMALL_BYTES_NUM;
    EntityStream entityStream = EntityStreams.newEntityStream(new BytesWriter(totalBytes, BYTE));
    StreamRequestBuilder builder = new StreamRequestBuilder(Bootstrap.createHttpURI(PORT, ERROR_RECEIVER_URI));
    StreamRequest request = builder.setMethod("POST").build(entityStream);
    final CountDownLatch latch = new CountDownLatch(1);
    final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
    Callback<StreamResponse> callback = new Callback<StreamResponse>()
    {
      @Override
      public void onError(Throwable e)
      {
        error.set(e);
        latch.countDown();
      }

      @Override
      public void onSuccess(StreamResponse result)
      {
        latch.countDown();
      }
    };
    client.streamRequest(request, callback);
    // Bounded wait: an unbounded await() would hang the whole suite if the callback never fires.
    Assert.assertTrue(latch.await(60000, TimeUnit.MILLISECONDS), "Timed out waiting for the callback");
    Assert.assertNotNull(error.get());
  }
}
/**
 * Sends {@code SMALL_BYTES_NUM} bytes to the rate-limited handler, whose reader delays every
 * 16th read request, and checks that the client's send time tracks the server's receive time.
 * That is: the receive took more than one second and the two timespans differ by less than 20%.
 *
 * @throws Exception if the wait for the response is interrupted
 */
@Test
public void testBackPressure() throws Exception
{
  for (Client client : clients())
  {
    final long totalBytes = SMALL_BYTES_NUM;
    TimedBytesWriter writer = new TimedBytesWriter(totalBytes, BYTE);
    EntityStream entityStream = EntityStreams.newEntityStream(writer);
    StreamRequestBuilder builder = new StreamRequestBuilder(Bootstrap.createHttpURI(PORT, RATE_LIMITED_URI));
    StreamRequest request = builder.setMethod("POST").build(entityStream);
    final AtomicInteger status = new AtomicInteger(-1);
    final CountDownLatch latch = new CountDownLatch(1);
    Callback<StreamResponse> callback = expectSuccessCallback(latch, status);
    client.streamRequest(request, callback);
    // Fail with an explicit message on timeout instead of a confusing status mismatch below.
    Assert.assertTrue(latch.await(60000, TimeUnit.MILLISECONDS), "Timed out waiting for the response");
    Assert.assertEquals(status.get(), RestStatus.OK);
    TimedBytesReader reader = _rateLimitedRequestHandler.getReader();
    Assert.assertNotNull(reader);
    // TestNG's assertEquals takes (actual, expected).
    Assert.assertEquals(reader.getTotalBytes(), totalBytes);
    Assert.assertTrue(reader.allBytesCorrect());
    long clientSendTimespan = writer.getStopTime() - writer.getStartTime();
    long serverReceiveTimespan = reader.getStopTime() - reader.getStartTime();
    Assert.assertTrue(serverReceiveTimespan > 1000,
        "Server receive timespan too short: " + serverReceiveTimespan);
    double diff = Math.abs(serverReceiveTimespan - clientSendTimespan);
    double diffRatio = diff / clientSendTimespan;
    // make it generous to reduce the chance of occasional test failures
    Assert.assertTrue(diffRatio < 0.2,
        "Client send timespan " + clientSendTimespan + " vs server receive timespan " + serverReceiveTimespan);
  }
}
/**
 * Stream request handler that attaches a {@link TimedBytesReader} to the request's entity
 * stream. It answers 200 OK when the reader's callback reports success, or wraps the failure
 * in a 500 {@link RestException} when it reports an error.
 */
private static class CheckRequestHandler implements StreamRequestHandler
{
  private final byte _b;
  // Reader created for the most recent request; exposed to tests through getReader().
  private TimedBytesReader _reader;

  CheckRequestHandler(byte b)
  {
    _b = b;
  }

  /**
   * Installs a fresh reader on the request's entity stream and ties the response to the
   * reader's completion callback.
   *
   * @param request        incoming streaming request
   * @param requestContext request context (unused)
   * @param callback       response callback, completed once the reader finishes or fails
   */
  @Override
  public void handleRequest(StreamRequest request, RequestContext requestContext, final Callback<StreamResponse> callback)
  {
    Callback<None> readerCallback = new Callback<None>()
    {
      @Override
      public void onError(Throwable e)
      {
        RestException restException = new RestException(RestStatus.responseForError(500, e));
        callback.onError(restException);
      }

      @Override
      public void onSuccess(None result)
      {
        RestResponse response = RestStatus.responseForStatus(RestStatus.OK, "");
        callback.onSuccess(Messages.toStreamResponse(response));
      }
    };
    _reader = createReader(_b, readerCallback);
    request.getEntityStream().setReader(_reader);
  }

  /**
   * @return the reader created for the most recent request, or {@code null} if no request
   *         has been handled yet
   */
  TimedBytesReader getReader()
  {
    return _reader;
  }

  /**
   * Factory hook for the reader; subclasses override it to customise reading behaviour.
   *
   * @param b              byte value passed to the reader
   * @param readerCallback callback handed to the reader
   * @return a new reader built from the given arguments
   */
  protected TimedBytesReader createReader(byte b, Callback<None> readerCallback)
  {
    // Use the parameter rather than the field so the hook honours its own argument.
    return new TimedBytesReader(b, readerCallback);
  }
}
/**
 * Variant of {@link CheckRequestHandler} whose reader throttles itself: every 16th call to
 * {@code requestMore} is postponed by {@code interval} milliseconds on the supplied
 * scheduler, while all other calls request the next chunk immediately.
 */
private static class RateLimitedRequestHandler extends CheckRequestHandler
{
  private final ScheduledExecutorService _scheduler;
  private final long _interval;

  RateLimitedRequestHandler(ScheduledExecutorService scheduler, long interval, byte b)
  {
    super(b);
    _scheduler = scheduler;
    _interval = interval;
  }

  @Override
  protected TimedBytesReader createReader(byte b, Callback<None> readerCallback)
  {
    return new TimedBytesReader(b, readerCallback)
    {
      // Number of requestMore invocations seen so far.
      private int _requestCount = 0;

      @Override
      public void requestMore(final ReadHandle rh)
      {
        _requestCount++;
        if (_requestCount % 16 != 0)
        {
          rh.request(1);
          return;
        }

        // Every 16th request is delayed.
        final Runnable delayedRequest = new Runnable()
        {
          @Override
          public void run()
          {
            rh.request(1);
          }
        };
        _scheduler.schedule(delayedRequest, _interval, TimeUnit.MILLISECONDS);
      }
    };
  }
}
/**
 * Handler simulating a buggy receiver: its reader asks for data on initialisation and then
 * throws a {@link RuntimeException} from {@code onDataAvailable}. The handler itself never
 * invokes the response callback.
 */
private static class ThrowWhenReceivingRequestHandler implements StreamRequestHandler
{
  @Override
  public void handleRequest(StreamRequest request, RequestContext requestContext, final Callback<StreamResponse> callback)
  {
    final Reader throwingReader = new Reader()
    {
      @Override
      public void onInit(ReadHandle rh)
      {
        rh.request(10);
      }

      @Override
      public void onDataAvailable(ByteString data)
      {
        throw new RuntimeException("some exception throw due to bug");
      }

      @Override
      public void onDone()
      {
        // nothing to do
      }

      @Override
      public void onError(Throwable e)
      {
        // nothing to do
      }
    };
    request.getEntityStream().setReader(throwingReader);
  }
}
/**
 * Builds a callback for requests that are expected to fail.
 *
 * On error it records the response status (only when the error is a {@link StreamException})
 * and releases the latch. On success it releases the latch and then throws.
 *
 * @param latch  latch counted down once the callback has been invoked
 * @param status receives the status of the error response, if one is available
 * @return the callback to pass to the client
 */
private static Callback<StreamResponse> expectErrorCallback(final CountDownLatch latch, final AtomicInteger status)
{
  final Callback<StreamResponse> failureExpected = new Callback<StreamResponse>()
  {
    @Override
    public void onError(Throwable e)
    {
      if (e instanceof StreamException)
      {
        status.set(((StreamException) e).getResponse().getStatus());
      }
      latch.countDown();
    }

    @Override
    public void onSuccess(StreamResponse result)
    {
      latch.countDown();
      throw new RuntimeException("Should have failed with 404");
    }
  };
  return failureExpected;
}
/**
 * Builds a callback for requests that are expected to succeed.
 *
 * On success it stores the response status and then releases the latch. On error it only
 * releases the latch, leaving {@code status} untouched.
 *
 * @param latch  latch counted down once the callback has been invoked
 * @param status receives the status of a successful response
 * @return the callback to pass to the client
 */
private static Callback<StreamResponse> expectSuccessCallback(final CountDownLatch latch, final AtomicInteger status)
{
  final Callback<StreamResponse> successExpected = new Callback<StreamResponse>()
  {
    @Override
    public void onSuccess(StreamResponse response)
    {
      // Record the status before releasing the waiter so it is set when the latch opens.
      status.set(response.getStatus());
      latch.countDown();
    }

    @Override
    public void onError(Throwable cause)
    {
      latch.countDown();
    }
  };
  return successExpected;
}
/**
 * Writer that fails part-way through: the underlying writer is sized at twice {@code total},
 * and once {@code afterWrite} is told that more than {@code total} has been written it
 * reports an error on the write handle.
 */
private static class ErrorWriter extends TimedBytesWriter
{
  // Written-count threshold above which the error is raised; doubled after it fires.
  private long _failAfter;

  ErrorWriter(long total, byte fill)
  {
    super(total * 2, fill);
    _failAfter = total;
  }

  @Override
  protected void afterWrite(WriteHandle wh, long written)
  {
    if (written <= _failAfter)
    {
      return;
    }
    _failAfter *= 2;
    wh.error(new RuntimeException("Error for testing"));
  }
}
}