package test.r2.integ; import com.linkedin.common.callback.Callback; import com.linkedin.common.callback.FutureCallback; import com.linkedin.common.util.None; import com.linkedin.r2.message.RequestContext; import com.linkedin.r2.message.rest.RestStatus; import com.linkedin.r2.message.stream.StreamRequest; import com.linkedin.r2.message.stream.StreamRequestBuilder; import com.linkedin.r2.message.stream.StreamResponse; import com.linkedin.r2.message.stream.StreamResponseBuilder; import com.linkedin.r2.message.stream.entitystream.DrainReader; import com.linkedin.r2.message.stream.entitystream.EntityStreams; import com.linkedin.r2.message.stream.entitystream.ReadHandle; import com.linkedin.r2.message.stream.entitystream.WriteHandle; import com.linkedin.r2.message.stream.entitystream.Writer; import com.linkedin.r2.sample.Bootstrap; import com.linkedin.r2.transport.common.Client; import com.linkedin.r2.transport.common.StreamRequestHandler; import com.linkedin.r2.transport.common.bridge.client.TransportClientAdapter; import com.linkedin.r2.transport.common.bridge.server.TransportDispatcher; import com.linkedin.r2.transport.common.bridge.server.TransportDispatcherBuilder; import com.linkedin.r2.transport.http.client.HttpClientFactory; import com.linkedin.r2.transport.http.server.HttpJettyServer; import com.linkedin.r2.transport.http.server.HttpServerFactory; import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Factory; import org.testng.annotations.Test; import java.net.URI; import java.util.HashMap; import java.util.Map; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; /** * @author Zhenkai Zhu */ public class TestStreamResponse extends AbstractStreamTest { private static final URI LARGE_URI = URI.create("/large"); private static final URI SMALL_URI = URI.create("/small"); private static final URI SERVER_ERROR_URI = URI.create("/error"); private static final URI HICCUP_URI = URI.create("/hiccup"); private BytesWriterRequestHandler _smallHandler; private final HttpJettyServer.ServletType _servletType; @Factory(dataProvider = "configs") public TestStreamResponse(HttpJettyServer.ServletType servletType) { _servletType = servletType; } @DataProvider public static Object[][] configs() { return new Object[][] {{HttpJettyServer.ServletType.RAP}, {HttpJettyServer.ServletType.ASYNC_EVENT}}; } @Override protected HttpServerFactory getServerFactory() { return new HttpServerFactory(_servletType); } @Override protected TransportDispatcher getTransportDispatcher() { _smallHandler = new BytesWriterRequestHandler(BYTE, SMALL_BYTES_NUM); return new TransportDispatcherBuilder() .addStreamHandler(LARGE_URI, new BytesWriterRequestHandler(BYTE, LARGE_BYTES_NUM)) .addStreamHandler(SMALL_URI, _smallHandler) .addStreamHandler(SERVER_ERROR_URI, new ErrorRequestHandler(BYTE, TINY_BYTES_NUM)) .addStreamHandler(HICCUP_URI, new HiccupRequestHandler(BYTE, LARGE_BYTES_NUM, _scheduler)) .build(); } @Override protected Map<String, String> getHttp1ClientProperties() { Map<String, String> clientProperties = super.getHttp1ClientProperties(); clientProperties.put(HttpClientFactory.HTTP_MAX_RESPONSE_SIZE, String.valueOf(LARGE_BYTES_NUM * 2)); clientProperties.put(HttpClientFactory.HTTP_REQUEST_TIMEOUT, "30000"); return clientProperties; } @Override protected Map<String, String> getHttp2ClientProperties() { Map<String, String> clientProperties = super.getHttp2ClientProperties(); clientProperties.put(HttpClientFactory.HTTP_MAX_RESPONSE_SIZE, String.valueOf(LARGE_BYTES_NUM * 2)); clientProperties.put(HttpClientFactory.HTTP_REQUEST_TIMEOUT, "30000"); return clientProperties; } @Test public void testResponseLarge() throws Exception { testResponse(Bootstrap.createHttpURI(PORT, LARGE_URI)); } @Test public void testResponseHiccup() throws Exception { testResponse(Bootstrap.createHttpURI(PORT, HICCUP_URI)); } private void testResponse(URI uri) throws Exception { for (Client client : clients()) { StreamRequestBuilder builder = new StreamRequestBuilder(uri); StreamRequest request = builder.build(EntityStreams.emptyStream()); final AtomicInteger status = new AtomicInteger(-1); final CountDownLatch latch = new CountDownLatch(1); final AtomicReference<Throwable> error = new AtomicReference<Throwable>(); final Callback<None> readerCallback = getReaderCallback(latch, error); final BytesReader reader = new BytesReader(BYTE, readerCallback); Callback<StreamResponse> callback = getCallback(status, readerCallback, reader); client.streamRequest(request, callback); latch.await(60000, TimeUnit.MILLISECONDS); Assert.assertNull(error.get()); Assert.assertEquals(status.get(), RestStatus.OK); Assert.assertEquals(reader.getTotalBytes(), LARGE_BYTES_NUM); Assert.assertTrue(reader.allBytesCorrect()); } } @Test public void testErrorWhileStreaming() throws Exception { HttpClientFactory clientFactory = new HttpClientFactory(); Map<String, String> clientProperties = new HashMap<String, String>(); clientProperties.put(HttpClientFactory.HTTP_REQUEST_TIMEOUT, "1000"); Client client = new TransportClientAdapter(_clientFactory.getClient(clientProperties), true); StreamRequestBuilder builder = new StreamRequestBuilder(Bootstrap.createHttpURI(PORT, SERVER_ERROR_URI)); StreamRequest request = builder.build(EntityStreams.emptyStream()); final AtomicInteger status = new AtomicInteger(-1); final CountDownLatch latch = new CountDownLatch(1); final AtomicReference<Throwable> error = new AtomicReference<Throwable>(); final Callback<None> readerCallback = getReaderCallback(latch, error); final BytesReader reader = new BytesReader(BYTE, readerCallback); Callback<StreamResponse> callback = getCallback(status, readerCallback, reader); client.streamRequest(request, callback); latch.await(2000, TimeUnit.MILLISECONDS); Assert.assertEquals(status.get(), RestStatus.OK); Throwable throwable = error.get(); Assert.assertNotNull(throwable); final FutureCallback<None> clientShutdownCallback = new FutureCallback<None>(); client.shutdown(clientShutdownCallback); clientShutdownCallback.get(); final FutureCallback<None> factoryShutdownCallback = new FutureCallback<None>(); clientFactory.shutdown(factoryShutdownCallback); factoryShutdownCallback.get(); } @Test public void testBackpressure() throws Exception { for (Client client : clients()) { StreamRequestBuilder builder = new StreamRequestBuilder(Bootstrap.createHttpURI(PORT, SMALL_URI)); StreamRequest request = builder.build(EntityStreams.emptyStream()); final AtomicInteger status = new AtomicInteger(-1); final CountDownLatch latch = new CountDownLatch(1); final AtomicReference<Throwable> error = new AtomicReference<Throwable>(); final Callback<None> readerCallback = getReaderCallback(latch, error); final TimedBytesReader reader = new TimedBytesReader(BYTE, readerCallback) { int count = 0; @Override protected void requestMore(final ReadHandle rh) { count++; if (count % 16 == 0) { _scheduler.schedule(new Runnable() { @Override public void run() { rh.request(1); } }, INTERVAL, TimeUnit.MILLISECONDS); } else { rh.request(1); } } }; Callback<StreamResponse> callback = getCallback(status, readerCallback, reader); client.streamRequest(request, callback); latch.await(60000, TimeUnit.MILLISECONDS); Assert.assertNull(error.get()); Assert.assertEquals(status.get(), RestStatus.OK); long serverSendTimespan = _smallHandler.getWriter().getStopTime() - _smallHandler.getWriter().getStartTime(); long clientReceiveTimespan = reader.getStopTime() - reader.getStartTime(); Assert.assertTrue(clientReceiveTimespan > 1000); double diff = Math.abs(clientReceiveTimespan - serverSendTimespan); double diffRatio = diff / serverSendTimespan; // make it generous to reduce the chance occasional test failures Assert.assertTrue(diffRatio < 0.2); } } private static class BytesWriterRequestHandler implements StreamRequestHandler { private final byte _b; private final long _bytesNum; private volatile TimedBytesWriter _writer; BytesWriterRequestHandler(byte b, long bytesNUm) { _b = b; _bytesNum = bytesNUm; } @Override public void handleRequest(StreamRequest request, RequestContext requestContext, final Callback<StreamResponse> callback) { request.getEntityStream().setReader(new DrainReader()); _writer = createWriter(_bytesNum, _b); StreamResponse response = buildResponse(_writer); callback.onSuccess(response); } TimedBytesWriter getWriter() { return _writer; } protected TimedBytesWriter createWriter(long bytesNum, byte b) { return new TimedBytesWriter(_bytesNum, _b); } StreamResponse buildResponse(Writer writer) { return new StreamResponseBuilder().build(EntityStreams.newEntityStream(writer)); } } private static class ErrorWriter extends TimedBytesWriter { private long _total; ErrorWriter(long total, byte fill) { super(total * 2, fill); _total = total; } @Override protected void afterWrite(WriteHandle wh, long written) { if (written > _total) { _total = _total * 2 + 1; wh.error(new RuntimeException("Error for testing")); markError(); } } } private static class ErrorRequestHandler extends BytesWriterRequestHandler { ErrorRequestHandler(byte b, long bytesNum) { super(b, bytesNum); } @Override StreamResponse buildResponse(Writer writer) { return new StreamResponseBuilder() // set the content-length to Integer.MAX_VALUE so that receiver knows there // is an error at the end of the stream. .setHeader("Content-Length", Integer.toString(Integer.MAX_VALUE)) .build(EntityStreams.newEntityStream(writer)); } @Override protected TimedBytesWriter createWriter(long bytesNum, byte b) { return new ErrorWriter(bytesNum, b); } } private static class HiccupWriter extends TimedBytesWriter { private final Random _random = new Random(); private final ScheduledExecutorService _scheduler; HiccupWriter(long total, byte fill, ScheduledExecutorService scheduler) { super(total, fill); _scheduler = scheduler; } @Override public void onWritePossible() { if (_random.nextInt() % 17 == 0) { _scheduler.schedule(new Runnable() { @Override public void run() { HiccupWriter.super.onWritePossible(); } }, _random.nextInt() % 200, TimeUnit.MICROSECONDS); } else { super.onWritePossible(); } } } private static class HiccupRequestHandler extends BytesWriterRequestHandler { private final ScheduledExecutorService _scheduler; HiccupRequestHandler(byte b, long bytesNum, ScheduledExecutorService scheduler) { super(b, bytesNum); _scheduler = scheduler; } @Override protected TimedBytesWriter createWriter(long bytesNum, byte b) { return new HiccupWriter(bytesNum, b, _scheduler); } } private Callback<None> getReaderCallback(final CountDownLatch latch, final AtomicReference<Throwable> error) { return new Callback<None>() { @Override public void onError(Throwable e) { error.set(e); latch.countDown(); } @Override public void onSuccess(None result) { latch.countDown(); } }; } private Callback<StreamResponse> getCallback(final AtomicInteger status, final Callback<None> readerCallback, final BytesReader reader) { return new Callback<StreamResponse>() { @Override public void onError(Throwable e) { readerCallback.onError(e); } @Override public void onSuccess(StreamResponse result) { status.set(result.getStatus()); result.getEntityStream().setReader(reader); } }; } }